System Info
Machine: Apple M2
transformers == 4.31.0.dev0
Related: openai/whisper#1478
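For completeness, a minimal sketch (my addition, not part of the original report) to capture the environment details that matter here, including whether PyTorch can actually see the MPS backend on Apple silicon; `transformers-cli env` prints a fuller report:

```python
import platform

import torch
import transformers

# Versions relevant to this report.
print("platform:", platform.platform())
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)
# On Apple silicon, confirm the MPS backend is available to PyTorch.
print("mps available:", torch.backends.mps.is_available())
```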
Running the code below fails with this error:
in <module>:9

     6 prompt_ids = processor.get_prompt_ids(prompt)
     7
     8 forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
❱    9 predicted_ids = model.generate(input_features, prompt_ids=prompt_ids, forced_decoder_ids=forced_decoder_ids,
    10                                max_new_tokens=3000)
    11 transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    12 print("elapsed:", time.time() - start_time, transcription)

/Users/diaojunxian/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:1664 in generate

  1661         if generation_config.return_timestamps:
  1662             logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
  1663
❱ 1664         return super().generate(
  1665             inputs,
  1666             generation_config,
  1667             logits_processor,

/Users/diaojunxian/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py:115 in decorate_context

   112     @functools.wraps(func)
   113     def decorate_context(*args, **kwargs):
   114         with ctx_factory():
❱  115             return func(*args, **kwargs)
   116
   117     return decorate_context
   118

/Users/diaojunxian/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/generation/utils.py:1522 in generate

  1519                 )
  1520
  1521             # 11. run greedy search
❱ 1522             return self.greedy_search(
  1523                 input_ids,
  1524                 logits_processor=logits_processor,
  1525                 stopping_criteria=stopping_criteria,

/Users/diaojunxian/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/generation/utils.py:2349 in greedy_search

  2346             if synced_gpus and this_peer_finished:
  2347                 continue  # don't waste resources running the code we don't need
  2348
❱ 2349             next_token_logits = outputs.logits[:, -1, :]
  2350
  2351             # pre-process distribution
  2352             next_tokens_scores = logits_processor(input_ids, next_token_logits)
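The traceback stops at the `outputs.logits[:, -1, :]` slice in `greedy_search`, and the exception message itself is cut off above. A hedged first debugging step (my suggestion, not from the original report) is to rule out the fp16 + MPS combination by re-running the same call in float32 on CPU; if this succeeds, the failure is specific to half precision or the "mps" device:

```python
# Hedged float32/CPU cross-check; reuses the local model path and audio from the report.
import librosa
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

base_model = "/Users/ddd/Documents/github/whisper-large-v2"
processor = WhisperProcessor.from_pretrained(base_model, language="zh", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(base_model).eval()  # float32, on CPU

src_signal, sample_rate = librosa.load(
    "/Users/ddd/Documents/gitlab/llm-train/yuyin/simple.m4a", sr=16000)
src_signal_demo = src_signal[23196064:23364576]  # same ~10.5 s slice as in the report

input_features = processor(src_signal_demo, sampling_rate=sample_rate,
                           return_tensors="pt").input_features  # float32, CPU tensor

prompt_ids = processor.get_prompt_ids("以下是普通话的句子")  # "The following is a Mandarin sentence"
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")

with torch.no_grad():
    # Whisper's decoder is capped at 448 positions, so max_new_tokens=3000 can never
    # be reached anyway; a value below the cap is used here.
    predicted_ids = model.generate(input_features, prompt_ids=prompt_ids,
                                   forced_decoder_ids=forced_decoder_ids,
                                   max_new_tokens=440)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))
```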
Both of the following snippets reproduce the error.
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor

base_model = "/Users/ddd/Documents/github/whisper-large-v2"
processor = WhisperProcessor.from_pretrained(base_model,
                                             language="zh",
                                             task="transcribe",
                                             local_files_only=True)

# Load the model in half precision.
model = WhisperForConditionalGeneration.from_pretrained(base_model,
                                                        device_map="auto",
                                                        local_files_only=True).half()
model.eval()

audio_file = "/Users/ddd/Documents/gitlab/llm-train/yuyin/simple.m4a"
src_signal, sample_rate = librosa.load(audio_file, sr=16000)

# Slice out ~10.5 s of audio: samples 23196064-23364576 at 16 kHz, i.e. ~1449.8 s to ~1460.3 s.
start = 23196064
end = 23364576
src_signal_demo = src_signal[start:end]

input_features = processor(src_signal_demo, sampling_rate=sample_rate,
                           return_tensors="pt").input_features.half().to("mps")

prompt = '以下是普通话的句子'  # "The following is a Mandarin sentence"
prompt_ids = processor.get_prompt_ids(prompt)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")

predicted_ids = model.generate(input_features, prompt_ids=prompt_ids,
                               forced_decoder_ids=forced_decoder_ids,
                               max_new_tokens=3000)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
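One thing worth checking in the snippet above (my observation, not confirmed as the cause): with `device_map="auto"` the model's placement is chosen by accelerate, while the features are moved to `"mps"` by hand, so the two can disagree. A hedged sketch that derives the device and dtype from the model instead:

```python
# Align the input features with wherever (and in whatever dtype) the model actually lives.
param = next(model.parameters())
input_features = processor(src_signal_demo, sampling_rate=sample_rate,
                           return_tensors="pt").input_features.to(device=param.device,
                                                                  dtype=param.dtype)
```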
The same error also occurs through the pipeline API:

import librosa
from transformers import pipeline

pipe = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v2",
    device="mps",
    chunk_length_s=30,  # if not specified, only generates up to `max_new_tokens`
    generate_kwargs={"num_beams": 5},  # same as the openai/whisper default
)

audio_file = "/Users/ddd/Documents/gitlab/llm-train/yuyin/simple.m4a"
src_signal, sample_rate = librosa.load(audio_file, sr=16000)
start = 23196064
end = 23364576
src_signal_demo = src_signal[start:end]

prompt = '以下是普通话的句子'  # "The following is a Mandarin sentence"
prompt_ids = pipe.tokenizer.get_prompt_ids(prompt, return_tensors="pt")
result = pipe(src_signal_demo,
              generate_kwargs={"language": "zh", "task": "transcribe", "prompt_ids": prompt_ids})
print(result["text"])
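A hedged aside on the pipeline call: passing a bare numpy array assumes it is already at the model's sampling rate. That holds here because the clip was explicitly resampled to 16 kHz, but the pipeline's dict input form makes the assumption explicit:

```python
# Explicitly tell the pipeline the sampling rate of the raw array.
result = pipe({"raw": src_signal_demo, "sampling_rate": sample_rate},
              generate_kwargs={"language": "zh", "task": "transcribe",
                               "prompt_ids": prompt_ids})
print(result["text"])
```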
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- Load the audio file.
- Slice the audio by sample indices.
- Add the prompt.
- Transcribe the audio slice; the error above is raised (an isolation sketch follows below).
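To narrow down the trigger, a hedged variation of the repro (my suggestion; it reuses `model`, `input_features`, `prompt_ids`, and `forced_decoder_ids` from the first snippet): run the same generate call with only the prompt, only the forced decoder ids, and then both, to see which combination fails:

```python
# Hedged isolation sketch: which argument combination triggers the failure?
combos = {
    "prompt_ids only": dict(prompt_ids=prompt_ids),
    "forced_decoder_ids only": dict(forced_decoder_ids=forced_decoder_ids),
    "both": dict(prompt_ids=prompt_ids, forced_decoder_ids=forced_decoder_ids),
}
for name, kwargs in combos.items():
    try:
        # 440 stays below Whisper's 448-position decoder cap.
        model.generate(input_features, max_new_tokens=440, **kwargs)
        print(name, "-> ok")
    except Exception as exc:  # log every failure mode rather than stopping at the first
        print(name, "->", type(exc).__name__, exc)
```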
Expected behavior
The sliced audio should be transcribed to text without raising an error.