Commit 68e85fc

[Flax Examples] Seq2Seq ASR Fine-Tuning Script (#21764)
* from seq2seq speech
* [Flax] Example script for speech seq2seq
* tests and fixes
* make style
* fix: label padding tokens
* fix: label padding tokens over list
* update ln names for Whisper
* try datasets iter loader
* create readme and append results
* style
* make style
* adjust lr
* use pt dataloader
* make fast
* pin gen max len
* finish
* add pt to requirements for test
* fix pt -> torch
* add accelerate
1 parent 3911774 commit 68e85fc
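One fix listed in the commit message, "fix: label padding tokens", refers to making the loss ignore padded label positions. A minimal sketch of the idea, assuming the common convention of replacing the pad token id with -100 in the labels (the helper name and token ids below are illustrative, not taken from the script):

```python
# Illustrative sketch (not the script's actual code): label sequences are
# padded to a common length, and the padded positions are replaced with -100
# so that the cross-entropy loss skips them.
def mask_label_padding(labels, pad_token_id, ignore_id=-100):
    """Replace pad_token_id with ignore_id in each label sequence."""
    return [
        [ignore_id if token == pad_token_id else token for token in seq]
        for seq in labels
    ]

batch = [[7, 8, 9, 0, 0], [4, 5, 0, 0, 0]]  # 0 is the assumed pad token id
print(mask_label_padding(batch, pad_token_id=0))
# [[7, 8, 9, -100, -100], [4, 5, -100, -100, -100]]
```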

File tree

5 files changed (+967, -1 lines)


examples/flax/_tests_requirements.txt

Lines changed: 3 additions & 1 deletion
```diff
@@ -5,4 +5,6 @@ nltk
 rouge-score
 seqeval
 tensorboard
-evaluate >= 0.2.0
+evaluate >= 0.2.0
+torch
+accelerate
```
Lines changed: 68 additions & 0 deletions
<!---
Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Automatic Speech Recognition - Flax Examples
## Sequence to Sequence

The script [`run_flax_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py)
can be used to fine-tune any [Flax Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.FlaxAutoModelForSpeechSeq2Seq)
for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition)
or a custom dataset. This includes the Whisper model from OpenAI or a warm-started Speech-Encoder-Decoder Model,
an example of which is included below.

### Whisper Model

We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model
weights, feature extractor and tokenizer. We simply have to specify the id of the fine-tuning dataset and the necessary
training hyperparameters.

The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint
on the Hindi subset of the [Common Voice 13](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) dataset.
Note that before running this script you must accept the dataset's [terms of use](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0)
and register your Hugging Face Hub token on your device by running `huggingface-cli login`.
```bash
python run_flax_speech_recognition_seq2seq.py \
    --model_name_or_path="openai/whisper-small" \
    --dataset_name="mozilla-foundation/common_voice_13_0" \
    --dataset_config_name="hi" \
    --language="hindi" \
    --train_split_name="train+validation" \
    --eval_split_name="test" \
    --output_dir="./whisper-small-hi-flax" \
    --per_device_train_batch_size="16" \
    --per_device_eval_batch_size="16" \
    --num_train_epochs="10" \
    --learning_rate="1e-4" \
    --warmup_steps="500" \
    --logging_steps="25" \
    --generation_max_length="40" \
    --preprocessing_num_workers="32" \
    --dataloader_num_workers="32" \
    --max_duration_in_seconds="30" \
    --text_column_name="sentence" \
    --overwrite_output_dir \
    --do_train \
    --do_eval \
    --predict_with_generate \
    --push_to_hub \
    --use_auth_token
```
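The flags above set a peak learning rate of 1e-4 with 500 warmup steps. As a rough illustration of how such a schedule behaves, here is a minimal linear-warmup/linear-decay sketch; the exact schedule shape and the total step count are assumptions here, not read from the script:

```python
# Sketch of a linear warmup + linear decay learning-rate schedule.
# peak_lr and warmup_steps mirror the flags above; total_steps is assumed.
def linear_warmup_linear_decay(step, peak_lr=1e-4, warmup_steps=500, total_steps=5000):
    """Ramp linearly to peak_lr over warmup_steps, then decay linearly to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # linear decay from peak_lr at warmup_steps down to 0 at total_steps
    decay = (total_steps - step) / (total_steps - warmup_steps)
    return peak_lr * max(0.0, decay)

print(linear_warmup_linear_decay(250))   # halfway through warmup: 5e-05
print(linear_warmup_linear_decay(500))   # peak: 0.0001
print(linear_warmup_linear_decay(5000))  # end of training: 0.0
```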
On a TPU v4-8, training should take approximately 25 minutes, with a final cross-entropy loss of 0.02 and a word error
rate of **34%**. See the checkpoint [sanchit-gandhi/whisper-small-hi-flax](https://huggingface.co/sanchit-gandhi/whisper-small-hi-flax)
for an example training run.
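The word error rate quoted above is computed in the script via the `evaluate`/`jiwer` requirements; for illustration only, a minimal pure-Python version (word-level edit distance divided by reference length):

```python
# Minimal WER sketch (the script itself relies on evaluate/jiwer, not this code).
def wer(reference: str, hypothesis: str) -> float:
    """Word error rate: word-level Levenshtein distance / reference word count."""
    ref, hyp = reference.split(), hypothesis.split()
    # Dynamic-programming edit distance over words.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + cost)  # substitution
    return d[len(ref)][len(hyp)] / len(ref)

print(round(wer("the cat sat", "the cat sit"), 3))  # 0.333
```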
Lines changed: 8 additions & 0 deletions
```
datasets[audio]>=2.14.0
jax>=0.3.6
jaxlib>=0.3.6
flax>=0.4.1
optax>=0.0.8
torch>=1.9.0
jiwer
evaluate
```
