System Info
- `transformers` version: 4.31.0
- Platform: Linux-5.15.109+-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.2
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.0.1+cu118 (False)
- Tensorflow version (GPU?): 2.12.0 (False)
- Flax version (CPU?/GPU?/TPU?): 0.7.1 (cpu)
- Jax version: 0.4.14
- JaxLib version: 0.4.14
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Bug Related
We keep `model.config.max_length=448`. The error happens when:
- `len(prompt_ids) + max_new_tokens > model.config.max_length + 1` (see the arithmetic sketch after this list)
- We fix `max_new_tokens` in `model.generate()`
- The length of the generated new tokens reaches its maximum. This mainly occurs when Whisper fails to predict the `eos` token and starts repeating some sequence of tokens.
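For concreteness, with the values used in the reproduction below, the inequality reduces to a simple threshold on the prompt length (a sketch of the arithmetic only, derived from the condition above):
# Arithmetic sketch: with whisper-tiny's max_length = 448 and the max_new_tokens = 224
# used below, len(prompt_ids) + max_new_tokens > model.config.max_length + 1
# holds as soon as the prompt is longer than 225 tokens.
max_length = 448
max_new_tokens = 224
prompt_threshold = max_length + 1 - max_new_tokens
print(prompt_threshold)  # 225
The script below reproduces the error: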
from transformers import (WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration)
from datasets import load_dataset
# Load dataset
fleurs_fr = load_dataset("google/fleurs", "fr_fr", split="test")
# Load Processor + Model
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# A sample chosen because it causes repetition
i = 512
input_speech = fleurs_fr[i]["audio"]["array"]
sr = fleurs_fr[i]["audio"]["sampling_rate"]
# Create big enough prompt text
# It should be sliced inside generate anyway
prompt_text = " bien," * 113
prompt_ids = processor.get_prompt_ids(prompt_text)
# Generate
input_features = processor(input_speech, return_tensors="pt",
sampling_rate=16e3).input_features
output_with_prompt = model.generate(input_features,
                                    language="fr",
                                    task="transcribe",
                                    prompt_ids=prompt_ids,
                                    max_new_tokens=224)
Output:
IndexError Traceback (most recent call last)
[<ipython-input-4-3420d576291f>](https://localhost:8080/#) in <cell line: 4>()
2 sampling_rate=16e3).input_features
3
----> 4 output_with_prompt = model.generate(input_features,
5 language="fr",
6 task="transcribe",
3 frames
[/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py](https://localhost:8080/#) in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, return_token_timestamps, **kwargs)
1747 )
1748
-> 1749 outputs = super().generate(
1750 inputs,
1751 generation_config,
[/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py](https://localhost:8080/#) in decorate_context(*args, **kwargs)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
116
117 return decorate_context
[/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py](https://localhost:8080/#) in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1536
1537 # 11. run greedy search
-> 1538 return self.greedy_search(
1539 input_ids,
1540 logits_processor=logits_processor,
[/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py](https://localhost:8080/#) in greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
2370 continue # don't waste resources running the code we don't need
2371
-> 2372 next_token_logits = outputs.logits[:, -1, :]
2373
2374 # pre-process distribution
IndexError: index -1 is out of bounds for dimension 1 with size 0
The bug might be caused by the absence of any check on `max_new_tokens` inside the `generate()` function, which suggests this may be a general generation bug rather than one specific to prompting.
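Until such a check exists, a possible user-side mitigation is to cap `max_new_tokens` against the space left by the truncated prompt. This is a rough sketch reusing the variables from the reproduction above; the exact bound is an assumption (it ignores the forced decoder tokens) and may need adjusting:
# Rough mitigation sketch (assumption, not the library's intended behavior):
# shrink max_new_tokens so the truncated prompt plus the new tokens stays within
# model.config.max_length. The slicing mirrors the line discussed in the Note below.
max_length = model.config.max_length                         # 448 for whisper-tiny
kept_prompt = prompt_ids[-max_length // 2 - 1:]              # what generate() keeps
safe_new_tokens = max(1, max_length - len(kept_prompt) - 8)  # "- 8" margin is a guess
output_with_prompt = model.generate(input_features,
                                    language="fr",
                                    task="transcribe",
                                    prompt_ids=prompt_ids,
                                    max_new_tokens=safe_new_tokens)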
Note
Also, as I was reading the code I noticed this line:
text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]
It slices the text prompt ids and keeps `(self.config.max_length // 2 + 1)` tokens instead of `(self.config.max_length // 2 - 1)` tokens as in the original Whisper code here.
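The difference is easy to verify on a dummy list (plain-Python sketch, using `max_length = 448` as in the model config):
# Slicing comparison sketch: the transformers slice keeps max_length // 2 + 1 tokens,
# while the original Whisper slice keeps max_length // 2 - 1 tokens.
max_length = 448
ids = list(range(1000))                          # stand-in for text_prompt_ids
kept_transformers = ids[-max_length // 2 - 1:]   # [-225:] -> 225 tokens
kept_original = ids[-(max_length // 2 - 1):]     # [-223:] -> 223 tokens
print(len(kept_transformers), len(kept_original))  # 225 223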
Expected behavior
- A clear warning or error about surpassing `model.max_length` (a hypothetical sketch of such a check follows this list).
- Being able to set `max_new_tokens=224` (= `max_length // 2`) during prompting.
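For the first point, a standalone guard along these lines could surface the problem early. It is a purely hypothetical sketch based on the inequality above, not actual transformers code, and the helper name is made up:
# Hypothetical helper (sketch): fail loudly before calling generate() instead of
# hitting an IndexError deep inside greedy_search.
def check_prompt_budget(prompt_ids, max_new_tokens, max_length):
    if len(prompt_ids) + max_new_tokens > max_length + 1:
        raise ValueError(
            f"len(prompt_ids) ({len(prompt_ids)}) + max_new_tokens ({max_new_tokens}) "
            f"exceeds model.config.max_length + 1 ({max_length + 1}); "
            "shorten the prompt or reduce max_new_tokens."
        )

check_prompt_budget(prompt_ids, max_new_tokens=224, max_length=model.config.max_length)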