Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions tests/entrypoints/openai/test_transcription_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,29 @@ async def test_bad_requests(mary_had_lamb):
language="hh",
temperature=0.0)

# Expect audio too long: repeat the timeseries
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
repeated_audio = np.tile(audio, 10)
# Repeated audio to buffer
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
with pytest.raises(openai.BadRequestError):
await client.audio.transcriptions.create(model=model_name,
file=buffer,
language="en",
temperature=0.0)

@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]

mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
repeated_audio = np.tile(audio, 10)
# Repeated audio to buffer
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=buffer,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert out.count("Mary had a little lamb") == 10


@pytest.mark.asyncio
Expand Down
227 changes: 145 additions & 82 deletions vllm/entrypoints/openai/serving_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
from collections.abc import AsyncGenerator
from math import ceil
from typing import Final, Optional, Union, cast

import numpy as np
from fastapi import Request

from vllm.config import ModelConfig
Expand Down Expand Up @@ -143,6 +145,8 @@
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio


class OpenAIServingTranscription(OpenAIServing):
Expand Down Expand Up @@ -178,7 +182,7 @@ async def _preprocess_transcription(
self,
request: TranscriptionRequest,
audio_data: bytes,
) -> tuple[PromptType, float]:
) -> tuple[list[PromptType], float]:
# Validate request
# TODO language should be optional and can be guessed.
# For now we default to en. See
Expand Down Expand Up @@ -206,22 +210,22 @@ async def _preprocess_transcription(
y, sr = librosa.load(bytes_)

duration = librosa.get_duration(y=y, sr=sr)
if duration > self.max_audio_clip_s:
raise ValueError(
f"Maximum clip duration ({self.max_audio_clip_s}s) "
"exceeded.")

prompt = {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": (y, sr),
chunks = [y] if duration < 30 else self._split_audio(y, sr)
prompts = []
for i, chunk in enumerate(chunks):
prompt = {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": (chunk, sr),
},
},
},
"decoder_prompt":
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
}
return cast(PromptType, prompt), duration
"decoder_prompt":
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
if i == 0 else ""
}
prompts.append(cast(PromptType, prompt))
return prompts, duration

# TODO (varun) : Make verbose response work !
async def create_transcription(
Expand Down Expand Up @@ -268,7 +272,7 @@ async def create_transcription(
"Currently do not support PromptAdapter for Transcription."
)

prompt, duration_s = await self._preprocess_transcription(
prompts, duration_s = await self._preprocess_transcription(
request=request,
audio_data=audio_data,
)
Expand All @@ -277,7 +281,8 @@ async def create_transcription(
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))

result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
None]]] = None
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
Expand All @@ -288,32 +293,36 @@ async def create_transcription(

self._log_inputs(
request_id,
prompt['decoder_prompt'], # type: ignore
prompts[0]['decoder_prompt'], # type: ignore
params=sampling_params,
lora_request=None,
prompt_adapter_request=None)

result_generator = self.engine_client.generate(
prompt,
sampling_params,
request_id,
)
list_result_generator = [
self.engine_client.generate(
prompt,
sampling_params,
request_id,
) for prompt in prompts
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

if request.stream:
return self.transcription_stream_generator(request,
result_generator,
list_result_generator,
request_id,
request_metadata,
duration_s)
# Non-streaming response.
try:
assert result_generator is not None
async for op in result_generator:
result = op
return TranscriptionResponse(text=result.outputs[0].text)
assert list_result_generator is not None
text = ""
for result_generator in list_result_generator:
async for op in result_generator:
text += op.outputs[0].text
return TranscriptionResponse(text=text)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
Expand All @@ -322,7 +331,7 @@ async def create_transcription(

async def transcription_stream_generator(
self, request: TranscriptionRequest,
result_generator: AsyncGenerator[RequestOutput, None],
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]:
created_time = int(time.time())
Expand All @@ -335,60 +344,65 @@ async def transcription_stream_generator(
include_usage = request.stream_include_usage \
if request.stream_include_usage else False
include_continuous_usage = request.stream_continuous_usage_stats\
if include_usage and request.stream_continuous_usage_stats\
else False
if include_usage and request.stream_continuous_usage_stats\
else False

try:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
# Do not account the 4-tokens `<|startoftranscript|>..`
# Could be negative when language token is not specified.
num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
# NOTE(NickLucche) user can't pass encoder prompts directly
# at least not to Whisper. One indicator of the encoder
# amount of processing is the log-mel spectogram length.
num_prompt_tokens += ceil(audio_duration_s *
self.model_sr / self.hop_length)

# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).

# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]

delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)

if output.finish_reason is None:
# Still generating, send delta update.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message)
else:
# Model is finished generating.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)

chunk = TranscriptionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)

# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)

data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
for result_generator in list_result_generator:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
# Do not account the 4-tokens `<|startoftranscript|>..`
# Could be negative when language token
# is not specified.
num_prompt_tokens = max(
len(res.prompt_token_ids) - 4, 0)
# NOTE(NickLucche) user can't pass encoder
# prompts directly at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectogram length.
num_prompt_tokens += ceil(
audio_duration_s * self.model_sr / self.hop_length)

# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).

# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]

delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)

if output.finish_reason is None:
# Still generating, send delta update.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message)
else:
# Model is finished generating.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)

chunk = TranscriptionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)

# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)

data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

# Once the final token is handled, if stream_options.include_usage
# is sent, send the usage.
Expand Down Expand Up @@ -422,3 +436,52 @@ async def transcription_stream_generator(
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"

def _split_audio(self, audio_data: np.ndarray,
sample_rate: int) -> list[np.ndarray]:
chunk_size = sample_rate * self.max_audio_clip_s
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
chunks = []
i = 0
while i < audio_data.shape[-1]:
if i + chunk_size >= audio_data.shape[-1]:
# handle last chunk
chunks.append(audio_data[..., i:])
break

# Find the best split point in the overlap region
search_start = i + chunk_size - overlap_size
search_end = min(i + chunk_size, audio_data.shape[-1])
split_point = self._find_split_point(audio_data, search_start,
search_end)

# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks

def _find_split_point(self, wav: np.ndarray, start_idx: int,
end_idx: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
end_idx: End index of search region
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]

# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
for i in range(0,
len(segment) - MIN_ENERGY_WINDOW_SIZE,
MIN_ENERGY_WINDOW_SIZE):
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
energy = (window**2).mean()**0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx