diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py
index 1cb0a39df513..8117e774951e 100644
--- a/tests/entrypoints/openai/test_transcription_validation.py
+++ b/tests/entrypoints/openai/test_transcription_validation.py
@@ -74,19 +74,29 @@ async def test_bad_requests(mary_had_lamb):
                                                  language="hh",
                                                  temperature=0.0)
 
-    # Expect audio too long: repeat the timeseries
-    mary_had_lamb.seek(0)
-    audio, sr = librosa.load(mary_had_lamb)
-    repeated_audio = np.tile(audio, 10)
-    # Repeated audio to buffer
-    buffer = io.BytesIO()
-    sf.write(buffer, repeated_audio, sr, format='WAV')
-    buffer.seek(0)
-    with pytest.raises(openai.BadRequestError):
-        await client.audio.transcriptions.create(model=model_name,
-                                                 file=buffer,
-                                                 language="en",
-                                                 temperature=0.0)
+
+@pytest.mark.asyncio
+async def test_long_audio_request(mary_had_lamb):
+    model_name = "openai/whisper-large-v3-turbo"
+    server_args = ["--enforce-eager"]
+
+    mary_had_lamb.seek(0)
+    audio, sr = librosa.load(mary_had_lamb)
+    repeated_audio = np.tile(audio, 10)
+    # Repeated audio to buffer
+    buffer = io.BytesIO()
+    sf.write(buffer, repeated_audio, sr, format='WAV')
+    buffer.seek(0)
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=buffer,
+            language="en",
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(transcription)['text']
+        assert out.count("Mary had a little lamb") == 10
 
 
 @pytest.mark.asyncio
diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index f667c7e9b3a9..60d66434ea5a 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import io
+import math
 import time
 from collections.abc import AsyncGenerator
 from math import ceil
 from typing import Final, Optional, Union, cast
 
+import numpy as np
 from fastapi import Request
 
 from vllm.config import ModelConfig
@@ -143,6 +145,8 @@
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
+OVERLAP_CHUNK_SECOND = 1
+MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio
 
 
 class OpenAIServingTranscription(OpenAIServing):
@@ -178,7 +182,7 @@ async def _preprocess_transcription(
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> tuple[PromptType, float]:
+    ) -> tuple[list[PromptType], float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
@@ -206,22 +210,22 @@ async def _preprocess_transcription(
             y, sr = librosa.load(bytes_)
 
         duration = librosa.get_duration(y=y, sr=sr)
-        if duration > self.max_audio_clip_s:
-            raise ValueError(
-                f"Maximum clip duration ({self.max_audio_clip_s}s) "
-                "exceeded.")
-
-        prompt = {
-            "encoder_prompt": {
-                "prompt": "",
-                "multi_modal_data": {
-                    "audio": (y, sr),
+        chunks = [y] if duration < 30 else self._split_audio(y, sr)
+        prompts = []
+        for i, chunk in enumerate(chunks):
+            prompt = {
+                "encoder_prompt": {
+                    "prompt": "",
+                    "multi_modal_data": {
+                        "audio": (chunk, sr),
+                    },
                 },
-            },
-            "decoder_prompt":
-            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
-        }
-        return cast(PromptType, prompt), duration
+                "decoder_prompt":
+                f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
+                if i == 0 else ""
+            }
+            prompts.append(cast(PromptType, prompt))
+        return prompts, duration
 
     # TODO (varun) : Make verbose response work !
     async def create_transcription(
@@ -268,7 +272,7 @@ async def create_transcription(
                     "Currently do not support PromptAdapter for Transcription."
                 )
 
-            prompt, duration_s = await self._preprocess_transcription(
+            prompts, duration_s = await self._preprocess_transcription(
                 request=request,
                 audio_data=audio_data,
             )
@@ -277,7 +281,8 @@ async def create_transcription(
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-        result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
+        list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
+                                                            None]]] = None
         try:
             # Unlike most decoder-only models, whisper generation length is not
             # constrained by the size of the input audio, which is mapped to a
@@ -288,32 +293,36 @@ async def create_transcription(
 
             self._log_inputs(
                 request_id,
-                prompt['decoder_prompt'],  # type: ignore
+                prompts[0]['decoder_prompt'],  # type: ignore
                 params=sampling_params,
                 lora_request=None,
                 prompt_adapter_request=None)
 
-            result_generator = self.engine_client.generate(
-                prompt,
-                sampling_params,
-                request_id,
-            )
+            list_result_generator = [
+                self.engine_client.generate(
+                    prompt,
+                    sampling_params,
+                    request_id,
+                ) for prompt in prompts
+            ]
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
         if request.stream:
            return self.transcription_stream_generator(request,
-                                                       result_generator,
+                                                       list_result_generator,
                                                        request_id,
                                                        request_metadata,
                                                        duration_s)
         # Non-streaming response.
         try:
-            assert result_generator is not None
-            async for op in result_generator:
-                result = op
-            return TranscriptionResponse(text=result.outputs[0].text)
+            assert list_result_generator is not None
+            text = ""
+            for result_generator in list_result_generator:
+                async for op in result_generator:
+                    text += op.outputs[0].text
+            return TranscriptionResponse(text=text)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
@@ -322,7 +331,7 @@ async def create_transcription(
 
     async def transcription_stream_generator(
             self, request: TranscriptionRequest,
-            result_generator: AsyncGenerator[RequestOutput, None],
+            list_result_generator: list[AsyncGenerator[RequestOutput, None]],
             request_id: str, request_metadata: RequestResponseMetadata,
             audio_duration_s: float) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
@@ -335,60 +344,65 @@ async def transcription_stream_generator(
         include_usage = request.stream_include_usage \
             if request.stream_include_usage else False
         include_continuous_usage = request.stream_continuous_usage_stats\
-            if include_usage and request.stream_continuous_usage_stats\
-            else False
+            if include_usage and request.stream_continuous_usage_stats\
+                else False
 
         try:
-            async for res in result_generator:
-                # On first result.
-                if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token is not specified.
-                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder prompts directly
-                    # at least not to Whisper. One indicator of the encoder
-                    # amount of processing is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(audio_duration_s *
-                                              self.model_sr / self.hop_length)
-
-                # We need to do it here, because if there are exceptions in
-                # the result_generator, it needs to be sent as the FIRST
-                # response (by the try...catch).
-
-                # Just one output (n=1) supported.
-                assert len(res.outputs) == 1
-                output = res.outputs[0]
-
-                delta_message = DeltaMessage(content=output.text)
-                completion_tokens += len(output.token_ids)
-
-                if output.finish_reason is None:
-                    # Still generating, send delta update.
-                    choice_data = TranscriptionResponseStreamChoice(
-                        delta=delta_message)
-                else:
-                    # Model is finished generating.
-                    choice_data = TranscriptionResponseStreamChoice(
-                        delta=delta_message,
-                        finish_reason=output.finish_reason,
-                        stop_reason=output.stop_reason)
-
-                chunk = TranscriptionStreamResponse(id=request_id,
-                                                    object=chunk_object_type,
-                                                    created=created_time,
-                                                    choices=[choice_data],
-                                                    model=model_name)
-
-                # handle usage stats if requested & if continuous
-                if include_continuous_usage:
-                    chunk.usage = UsageInfo(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=num_prompt_tokens + completion_tokens,
-                    )
-
-                data = chunk.model_dump_json(exclude_unset=True)
-                yield f"data: {data}\n\n"
+            for result_generator in list_result_generator:
+                async for res in result_generator:
+                    # On first result.
+                    if res.prompt_token_ids is not None:
+                        # Do not account the 4-tokens `<|startoftranscript|>..`
+                        # Could be negative when language token
+                        # is not specified.
+                        num_prompt_tokens = max(
+                            len(res.prompt_token_ids) - 4, 0)
+                        # NOTE(NickLucche) user can't pass encoder
+                        # prompts directly at least not to Whisper.
+                        # One indicator of the encoder amount of processing
+                        # is the log-mel spectogram length.
+                        num_prompt_tokens += ceil(
+                            audio_duration_s * self.model_sr / self.hop_length)
+
+                    # We need to do it here, because if there are exceptions in
+                    # the result_generator, it needs to be sent as the FIRST
+                    # response (by the try...catch).
+
+                    # Just one output (n=1) supported.
+                    assert len(res.outputs) == 1
+                    output = res.outputs[0]
+
+                    delta_message = DeltaMessage(content=output.text)
+                    completion_tokens += len(output.token_ids)
+
+                    if output.finish_reason is None:
+                        # Still generating, send delta update.
+                        choice_data = TranscriptionResponseStreamChoice(
+                            delta=delta_message)
+                    else:
+                        # Model is finished generating.
+                        choice_data = TranscriptionResponseStreamChoice(
+                            delta=delta_message,
+                            finish_reason=output.finish_reason,
+                            stop_reason=output.stop_reason)
+
+                    chunk = TranscriptionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+
+                    # handle usage stats if requested & if continuous
+                    if include_continuous_usage:
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=num_prompt_tokens + completion_tokens,
+                        )
+
+                    data = chunk.model_dump_json(exclude_unset=True)
+                    yield f"data: {data}\n\n"
 
         # Once the final token is handled, if stream_options.include_usage
         # is sent, send the usage.
@@ -422,3 +436,52 @@ async def transcription_stream_generator(
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
+
+    def _split_audio(self, audio_data: np.ndarray,
+                     sample_rate: int) -> list[np.ndarray]:
+        chunk_size = sample_rate * self.max_audio_clip_s
+        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunks = []
+        i = 0
+        while i < audio_data.shape[-1]:
+            if i + chunk_size >= audio_data.shape[-1]:
+                # handle last chunk
+                chunks.append(audio_data[..., i:])
+                break
+
+            # Find the best split point in the overlap region
+            search_start = i + chunk_size - overlap_size
+            search_end = min(i + chunk_size, audio_data.shape[-1])
+            split_point = self._find_split_point(audio_data, search_start,
+                                                 search_end)
+
+            # Extract chunk up to the split point
+            chunks.append(audio_data[..., i:split_point])
+            i = split_point
+        return chunks
+
+    def _find_split_point(self, wav: np.ndarray, start_idx: int,
+                          end_idx: int) -> int:
+        """Find the best point to split audio by
+        looking for silence or low amplitude.
+        Args:
+            wav: Audio tensor [1, T]
+            start_idx: Start index of search region
+            end_idx: End index of search region
+        Returns:
+            Index of best splitting point
+        """
+        segment = wav[start_idx:end_idx]
+
+        # Calculate RMS energy in small windows
+        min_energy = math.inf
+        quietest_idx = 0
+        for i in range(0,
+                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
+                       MIN_ENERGY_WINDOW_SIZE):
+            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+            energy = (window**2).mean()**0.5
+            if energy < min_energy:
+                quietest_idx = i + start_idx
+                min_energy = energy
+        return quietest_idx
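
Note (illustration, not part of the patch): the new `_split_audio` and `_find_split_point` helpers chunk long audio into pieces of at most 30 seconds, placing each cut at the quietest 100 ms RMS window inside the final one-second overlap region of the chunk, so a word is less likely to be cut in half. The standalone sketch below mirrors that logic under those assumptions; the module-level names `split_audio` and `find_split_point`, the copied constants, and the toy 70-second signal are stand-ins for illustration only, not vLLM APIs.

import numpy as np

# Constants mirror the patch; the helpers below are standalone stand-ins for
# OpenAIServingTranscription._split_audio / _find_split_point.
MAX_AUDIO_CLIP_S = 30
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600  # ~100 ms at 16 kHz


def find_split_point(wav: np.ndarray, start_idx: int, end_idx: int) -> int:
    """Return the start of the quietest RMS window in [start_idx, end_idx)."""
    segment = wav[start_idx:end_idx]
    min_energy, quietest_idx = np.inf, start_idx
    for i in range(0, len(segment) - MIN_ENERGY_WINDOW_SIZE,
                   MIN_ENERGY_WINDOW_SIZE):
        window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
        energy = float(np.sqrt(np.mean(window**2)))
        if energy < min_energy:
            quietest_idx, min_energy = i + start_idx, energy
    return quietest_idx


def split_audio(audio: np.ndarray, sample_rate: int) -> list[np.ndarray]:
    """Cut audio into <=30 s chunks, splitting inside the 1 s overlap region."""
    chunk_size = sample_rate * MAX_AUDIO_CLIP_S
    overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
    chunks, i = [], 0
    while i < audio.shape[-1]:
        if i + chunk_size >= audio.shape[-1]:
            chunks.append(audio[..., i:])  # remainder fits in one chunk
            break
        split_point = find_split_point(audio, i + chunk_size - overlap_size,
                                       min(i + chunk_size, audio.shape[-1]))
        chunks.append(audio[..., i:split_point])
        i = split_point
    return chunks


if __name__ == "__main__":
    sr = 16000
    # 70 s toy signal: a steady tone with short silences near the 30 s marks.
    t = np.arange(70 * sr) / sr
    audio = 0.1 * np.sin(2 * np.pi * 220.0 * t)
    audio[int(29.5 * sr):int(29.8 * sr)] = 0.0
    audio[int(59.2 * sr):int(59.6 * sr)] = 0.0
    durations = [round(c.shape[-1] / sr, 2) for c in split_audio(audio, sr)]
    print(durations)  # cuts land in the silences, e.g. [29.5, 29.7, 10.8]

In the patch itself, only the first chunk's decoder prompt carries the `<|startoftranscript|>` prefix and `request.prompt`; later chunks get an empty decoder prompt, and the per-chunk outputs are concatenated in order for the non-streaming response.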