1515"""Tokenization classes for Whisper."""
1616import json
1717import os
18+ import re
1819from functools import lru_cache
1920from typing import List , Optional , Tuple
2021
@@ -190,6 +191,7 @@ def __init__(
         self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task
@@ -269,10 +271,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
                 # strip timestamp tokens from the text output
-                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+                text = self._decode(sliced_tokens)
+                text = self._filter_timestamp_ids(text)
                 offsets.append(
                     {
-                        "text": self._decode(sliced_tokens),
+                        "text": text,
                         "timestamp": (
                             start_timestamp_position * time_precision,
                             end_timestamp_position * time_precision,
@@ -296,9 +300,7 @@ def timestamp_ids(self, time_precision=0.02):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.
 
@@ -308,24 +310,18 @@ def _preprocess_token_ids(
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
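Side note (not part of the diff): despite its `token_ids` parameter name, the new `_filter_timestamp_ids` helper is applied to an already-decoded string, and the `timestamp_pat` compiled in `__init__` simply strips textual markers such as `<|0.00|>`. A minimal standalone sketch of that behaviour, with a made-up decoded string:

```python
import re

# Same pattern as the one compiled in __init__ above.
timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

# Hypothetical decoded output still containing timestamp tokens.
decoded = "<|0.00|> Hello world.<|1.24|><|1.24|> How are you?<|2.80|>"

# re.sub removes every "<|x.xx|>" marker, leaving only the transcription text.
print(re.sub(timestamp_pat, "", decoded))
# -> " Hello world. How are you?"
```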
@@ -356,6 +352,8 @@ def decode(
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -364,8 +362,6 @@ def decode(
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -380,6 +376,9 @@ def decode(
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
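For reference, `decode(..., output_offsets=True)` still returns the per-segment structure built by `_compute_offsets`; the only change is that each segment's text is cleaned with `_filter_timestamp_ids` rather than by filtering timestamp token ids. An illustrative sketch of the resulting shape, with invented text and times and the default `time_precision=0.02`:

```python
# Hypothetical shape of the value returned by decode(..., output_offsets=True);
# the text and times here are made up for illustration. Each "timestamp" tuple
# is (start_position * time_precision, end_position * time_precision).
example = {
    "text": " Hello world. How are you?",
    "offsets": [
        {"text": " Hello world.", "timestamp": (0.0, 1.24)},
        {"text": " How are you?", "timestamp": (1.24, 2.8)},
    ],
}
```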