71 changes: 61 additions & 10 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -15,6 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
from functools import lru_cache
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
@@ -546,6 +547,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
offsets.append(
{
"text": self._decode(sliced_tokens),
@@ -559,6 +562,47 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

return offsets

@lru_cache
def timestamp_ids(self, time_precision=0.02):
"""
Compute the timestamp token ids for a given precision and save them to the least-recently used (LRU) cache.

Args:
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
"""
return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])

def _preprocess_token_ids(
self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
):
"""
Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.

Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Typically obtained using the `__call__` method of the tokenizer.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
removed.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
filtered out from the token ids.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
"""
if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)

if not decode_with_timestamps:
# filter timestamp tokens if they are contained in the vocab
timestamp_ids = self.timestamp_ids(time_precision=time_precision)
token_ids = [token for token in token_ids if token not in timestamp_ids]

return token_ids

def decode(
self,
token_ids,
@@ -593,33 +637,40 @@ def decode(
Returns:
`str`: The decoded sentence.
"""
-text = super().decode(
filtered_ids = self._preprocess_token_ids(
token_ids,
skip_special_tokens=skip_special_tokens,
decode_with_timestamps=decode_with_timestamps,
time_precision=time_precision,
)

text = super().decode(
filtered_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
decode_with_timestamps=decode_with_timestamps,
**kwargs,
)
if decode_with_timestamps:
# legacy method to decode timestamps when not included in the tokenizer vocabulary
text = self._decode_with_timestamps(
-token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
)
# retrieve offsets
if output_offsets:
-offsets = None
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
@sanchit-gandhi (Contributor, Author) commented on Sep 9, 2023:

We want to use `token_ids` here, not `filtered_ids`, since we need the timestamp ids to be present so that we can compute the offsets.

We later strip the timestamp ids from the chunk outputs in the `_compute_offsets` method.
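
A minimal sketch of what that buys us (illustrative only; it assumes the `openai/whisper-tiny` checkpoint and builds a toy sequence by hand with `convert_tokens_to_ids`):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Toy sequence: ordinary text ids wrapped in explicit timestamp tokens.
ids = (
    tokenizer.convert_tokens_to_ids(["<|0.00|>"])
    + tokenizer(" Hey", add_special_tokens=False).input_ids
    + tokenizer.convert_tokens_to_ids(["<|1.00|>"])
)

out = tokenizer.decode(ids, output_offsets=True)
# Offsets are computed from the unfiltered ids (timestamps intact); the chunk
# text then has the timestamp tokens stripped inside _compute_offsets.
print(out["offsets"])  # expected: [{'text': ' Hey', 'timestamp': (0.0, 1.0)}]
```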

return {"text": text, "offsets": offsets}
return text

def _decode(
-self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
normalize: bool = False,
decode_with_timestamps: bool = False,
**kwargs,
) -> str:
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)

filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)

# To avoid mixing byte-level and unicode for byte-level BPE
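Taken together, the `tokenization_whisper.py` changes mean that timestamp tokens, which now live in the vocabulary, are filtered out of plain `decode` calls unless explicitly requested. A rough sketch of the two decode modes (assuming the `openai/whisper-tiny` checkpoint; the timestamp positions are illustrative):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

ids = (
    tokenizer.convert_tokens_to_ids(["<|0.00|>"])
    + tokenizer(" Hey", add_special_tokens=False).input_ids
    + tokenizer.convert_tokens_to_ids(["<|1.00|>"])
)

# Default: _preprocess_token_ids drops the timestamp ids before decoding.
print(tokenizer.decode(ids))                               # expected: " Hey"

# Opt-in: timestamp ids are kept and rendered back into the text.
print(tokenizer.decode(ids, decode_with_timestamps=True))  # expected: "<|0.00|> Hey<|1.00|>"
```
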
65 changes: 57 additions & 8 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -15,6 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
from functools import lru_cache
from typing import TYPE_CHECKING, List, Optional, Tuple

import numpy as np
@@ -255,6 +256,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
offsets.append(
{
"text": self._decode(sliced_tokens),
@@ -268,6 +271,49 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

return offsets

@lru_cache
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.timestamp_ids
def timestamp_ids(self, time_precision=0.02):
"""
Compute the timestamp token ids for a given precision and save them to the least-recently used (LRU) cache.

Args:
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
"""
return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
def _preprocess_token_ids(
self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
):
"""
Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.

Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Typically obtained using the `__call__` method of the tokenizer.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
removed.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
filtered out from the token ids.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
"""
if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)

if not decode_with_timestamps:
# filter timestamp tokens if they are contained in the vocab
timestamp_ids = self.timestamp_ids(time_precision=time_precision)
token_ids = [token for token in token_ids if token not in timestamp_ids]

return token_ids

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
def decode(
self,
@@ -303,29 +349,32 @@ def decode(
Returns:
`str`: The decoded sentence.
"""
-text = super().decode(
filtered_ids = self._preprocess_token_ids(
token_ids,
skip_special_tokens=skip_special_tokens,
decode_with_timestamps=decode_with_timestamps,
time_precision=time_precision,
)

text = super().decode(
filtered_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
decode_with_timestamps=decode_with_timestamps,
**kwargs,
)
if decode_with_timestamps:
# legacy method to decode timestamps when not included in the tokenizer vocabulary
text = self._decode_with_timestamps(
-token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
)
# retrieve offsets
if output_offsets:
-offsets = None
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
return {"text": text, "offsets": offsets}
return text

def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
if kwargs["skip_special_tokens"]:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)

text = super()._decode(*args, **kwargs)

if normalize:
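Because `timestamp_ids` is wrapped in `functools.lru_cache`, the 1501 timestamp ids (`<|0.00|>` through `<|30.00|>` at 0.02 s steps) are computed once per tokenizer and `time_precision`, then reused on every decode. A small sanity-check sketch under the same checkpoint assumption:

```python
from transformers import WhisperTokenizerFast

tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny")

ids = tokenizer.timestamp_ids()              # first call populates the cache
print(len(ids))                              # expected: 1501
print(tokenizer.timestamp_ids() is ids)      # expected: True — same cached list
print(tokenizer.timestamp_ids.cache_info())  # hit/miss counters from functools.lru_cache
```

Note that `lru_cache` keys on the bound instance as well, so each tokenizer object gets its own cache entry.
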
7 changes: 2 additions & 5 deletions tests/models/whisper/test_tokenization_whisper.py
@@ -52,14 +52,13 @@ def test_convert_token_and_id(self):
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)

-@unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())

self.assertEqual(vocab_keys[0], "!")
self.assertEqual(vocab_keys[1], '"')
-self.assertEqual(vocab_keys[-1], "<|notimestamps|>")
-self.assertEqual(len(vocab_keys), 50364)
self.assertEqual(vocab_keys[-1], "<|30.00|>")
self.assertEqual(len(vocab_keys), 51865)

def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 50258)
@@ -117,7 +116,6 @@ def test_tokenizer_integration(self):
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
)

-@unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
def test_output_offsets(self):
tokenizer = self.get_tokenizer()
previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
@@ -400,7 +398,6 @@ def test_batch_encoding_decoding(self):
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
self.assertListEqual(batch, transcription)

-@unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
def test_offset_decoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# fmt: off
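The re-enabled tests pin down the new vocabulary layout: the 1501 timestamp tokens are now counted as part of the full vocab, while `vocab_size` still reports only the base vocabulary. A quick check mirroring the updated asserts (assuming the multilingual `openai/whisper-tiny` checkpoint):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
vocab_keys = list(tokenizer.get_vocab().keys())

print(tokenizer.vocab_size)  # expected: 50258 — base vocabulary only
print(len(vocab_keys))       # expected: 51865 — base vocab + added special/timestamp tokens
print(vocab_keys[-1])        # expected: '<|30.00|>' — the last timestamp token
```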