Commit 211f93a

[Whisper Tokenizer] Make decoding faster after adding timestamps (#26299)
make decoding faster
1 parent 4e931a8 commit 211f93a
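
The speed-up comes from moving timestamp filtering out of the token-id loop. Previously, every plain `decode` call rebuilt the list of 1501 timestamp ids via `timestamp_ids()` (one `convert_tokens_to_ids` lookup per marker) and then checked `token not in timestamp_ids` — a linear scan of that list — for every token being decoded. The commit instead decodes the ids untouched and strips the `<|d.dd|>` markers from the resulting string with a single precompiled regex. A minimal sketch of the two strategies, assuming an illustrative id range for the timestamp tokens:

import re

timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
timestamp_ids = list(range(50364, 50364 + 1501))  # assumed id range, for illustration only

def filter_before_decode(token_ids):
    # old path: a linear scan of a 1501-element list for every token id
    return [t for t in token_ids if t not in timestamp_ids]

def strip_after_decode(text):
    # new path: one pass of a precompiled regex over the decoded string
    return timestamp_pat.sub("", text)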

2 files changed: +30 additions, -34 deletions


src/transformers/models/whisper/tokenization_whisper.py

Lines changed: 14 additions & 17 deletions
@@ -314,6 +314,7 @@ def __init__(
 
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         super().__init__(
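
The new `timestamp_pat` matches Whisper's timestamp markers (`<|0.00|>` up to `<|30.00|>`) in decoded text. A quick illustration with a made-up transcript string:

import re

timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
print(re.sub(timestamp_pat, "", "<|0.00|> Hello world.<|2.40|>"))  # prints ' Hello world.'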
@@ -560,10 +561,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
@@ -585,9 +588,7 @@ def timestamp_ids(self, time_precision=0.02):
         """
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
 
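
As the `timestamp_ids` body above shows, the timestamp vocabulary contains 1500 + 1 markers spanning 0.00s to 30.00s in `time_precision` (0.02s) steps; its size is exactly what made the removed per-token membership test expensive. A sketch of the enumeration:

time_precision = 0.02
tokens = ["<|%.2f|>" % (i * time_precision) for i in range(1500 + 1)]
print(tokens[0], tokens[1], tokens[-1])  # <|0.00|> <|0.02|> <|30.00|>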
@@ -597,24 +598,17 @@ def _preprocess_token_ids(
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     def decode(
         self,
         token_ids,
@@ -644,6 +638,8 @@ def decode(
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -652,8 +648,6 @@ def decode(
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -668,6 +662,9 @@ def decode(
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
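
With these changes, `decode` always decodes the full id sequence once and post-processes the string: markers are kept when `decode_with_timestamps=True` and regex-stripped in the new `else` branch otherwise. A usage sketch — the checkpoint name is an assumption and the commented outputs show the expected shape, not verified values:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")  # assumed checkpoint
ts = tokenizer.timestamp_ids()  # ids for <|0.00|> ... <|30.00|>
ids = [ts[0]] + tokenizer(" Hello world.", add_special_tokens=False).input_ids + [ts[120]]

tokenizer.decode(ids, decode_with_timestamps=True)  # expected: '<|0.00|> Hello world.<|2.40|>'
tokenizer.decode(ids)                               # expected: ' Hello world.'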

src/transformers/models/whisper/tokenization_whisper_fast.py

Lines changed: 16 additions & 17 deletions
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+import re
 from functools import lru_cache
 from typing import List, Optional, Tuple
 

@@ -190,6 +191,7 @@ def __init__(
         self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task
@@ -269,10 +271,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
@@ -296,9 +300,7 @@ def timestamp_ids(self, time_precision=0.02):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
 
@@ -308,24 +310,18 @@ def _preprocess_token_ids(
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
@@ -356,6 +352,8 @@ def decode(
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -364,8 +362,6 @@ def decode(
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -380,6 +376,9 @@ def decode(
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
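
Because the fast tokenizer tracks the slow one through the `# Copied from` markers, both classes now share the same regex-stripping behavior, so their default `decode` output should stay identical. A plausible sanity check, with the checkpoint name again assumed:

from transformers import WhisperTokenizer, WhisperTokenizerFast

slow = WhisperTokenizer.from_pretrained("openai/whisper-tiny")   # assumed checkpoint
fast = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny")
ids = [slow.timestamp_ids()[0]] + slow(" same text", add_special_tokens=False).input_ids

# both paths should produce the same timestamp-free text
assert slow.decode(ids) == fast.decode(ids)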
