
Whisper Prompting max_new_tokens #25422

@Helene-Maxcici

Description


System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.15.109+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu118 (False)
  • Tensorflow version (GPU?): 2.12.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.1 (cpu)
  • Jax version: 0.4.14
  • JaxLib version: 0.4.14
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Bug Related

We keep model.config.max_length=448. The error happens when all of the following hold:

  1. len(prompt_ids) + max_new_tokens > model.config.max_length + 1,
  2. max_new_tokens is set explicitly in model.generate(), and
  3. the number of newly generated tokens reaches that maximum, which mainly happens when Whisper fails to predict the eos token and starts repeating a sequence of tokens.

The following script reproduces the error:
from transformers import (WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration)
from datasets import load_dataset

# Load dataset
fleurs_fr = load_dataset("google/fleurs", "fr_fr", split="test")

# Load Processor + Model
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Choose a sample that causes repetition
i = 512
input_speech = fleurs_fr[i]["audio"]["array"]
sr = fleurs_fr[i]["audio"]["sampling_rate"]

# Create a big enough prompt text
# (it should be sliced inside generate anyway)
prompt_text = " bien," * 113
prompt_ids = processor.get_prompt_ids(prompt_text)

# Generate
input_features = processor(input_speech, return_tensors="pt",
                            sampling_rate=16e3).input_features

output_with_prompt = model.generate(input_features,
                                    language="fr",
                                    task="transcribe",
                                    prompt_ids=prompt_ids,
                                    max_new_tokens=224)

Output:

IndexError                                Traceback (most recent call last)
<ipython-input-4-3420d576291f> in <cell line: 4>()
      2                             sampling_rate=16e3).input_features
      3 
----> 4 output_with_prompt = model.generate(input_features,
      5                                     language="fr",
      6                                     task="transcribe",

3 frames
/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, return_token_timestamps, **kwargs)
   1747                 )
   1748 
-> 1749         outputs = super().generate(
   1750             inputs,
   1751             generation_config,

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1536 
   1537             # 11. run greedy search
-> 1538             return self.greedy_search(
   1539                 input_ids,
   1540                 logits_processor=logits_processor,

/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py in greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2370                 continue  # don't waste resources running the code we don't need
   2371 
-> 2372             next_token_logits = outputs.logits[:, -1, :]
   2373 
   2374             # pre-process distribution

IndexError: index -1 is out of bounds for dimension 1 with size 0

The bug seems to come from generate() placing no constraint on max_new_tokens relative to model.config.max_length, so this may be a general generation issue rather than one specific to prompting.
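For illustration, a pre-flight check of the kind hinted at above could look like the following. This is my own sketch (not existing library behaviour); it ignores the handful of forced decoder tokens (language/task ids), so it only approximates the real budget, and it reuses model and prompt_ids from the script above.

# Hypothetical sanity check before calling model.generate()
max_new_tokens = 224
max_length = model.config.max_length  # 448 for Whisper
total = len(prompt_ids) + max_new_tokens
if total > max_length:
    raise ValueError(
        f"len(prompt_ids) ({len(prompt_ids)}) + max_new_tokens ({max_new_tokens}) "
        f"= {total} exceeds model.config.max_length ({max_length})"
    )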

Note

Also, as I was reading the code I noticed this line:
text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]

It slices the text prompt ids and keeps (self.config.max_length // 2 + 1) tokens instead of the (self.config.max_length // 2 - 1) tokens kept in the original OpenAI Whisper code.
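For reference, the difference in how many tokens are kept can be checked with plain Python (448 stands in for self.config.max_length):

max_length = 448
ids = list(range(300))  # stand-in for a long list of prompt ids

# transformers slice: -448 // 2 - 1 == -225, so the last 225 tokens are kept
print(len(ids[-max_length // 2 - 1:]))    # 225 == max_length // 2 + 1

# original Whisper slice: keeps the last 223 tokens
print(len(ids[-(max_length // 2 - 1):]))  # 223 == max_length // 2 - 1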

Expected behavior

  • A clear warning or error when the requested generation length exceeds model.config.max_length.
  • Being able to set max_new_tokens=224 (= max_length // 2) during prompting.
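
Until then, a possible workaround (my own sketch, not something the library provides) is to cap max_new_tokens against the remaining budget before calling generate(). The margin of 4 tokens for the forced decoder ids is a rough guess and may need adjusting:

# Hypothetical workaround: keep sliced prompt + forced ids + new tokens under max_length
forced_tokens_margin = 4  # rough allowance for <|startofprev|>, language/task ids, etc.
sliced_prompt_len = min(len(prompt_ids), model.config.max_length // 2 + 1)
safe_max_new_tokens = model.config.max_length - sliced_prompt_len - forced_tokens_margin

output_with_prompt = model.generate(input_features,
                                    language="fr",
                                    task="transcribe",
                                    prompt_ids=prompt_ids,
                                    max_new_tokens=safe_max_new_tokens)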
