<!---
Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Automatic Speech Recognition - Flax Examples

## Sequence to Sequence

The script [`run_flax_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py)
can be used to fine-tune any [Flax Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.FlaxAutoModelForSpeechSeq2Seq)
for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition)
or a custom dataset. This includes OpenAI's Whisper model and warm-started Speech-Encoder-Decoder models;
an example of the former is included below.

### Whisper Model

We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model
weights, feature extractor and tokenizer. We simply have to specify the id of the fine-tuning dataset and the necessary
training hyperparameters.

The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint
on the Hindi subset of the [Common Voice 13](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) dataset.
Note that before running this script you must accept the dataset's [terms of use](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0)
and register your Hugging Face Hub token on your device by running `huggingface-cli login`.

```bash
python run_flax_speech_recognition_seq2seq.py \
	--model_name_or_path="openai/whisper-small" \
	--dataset_name="mozilla-foundation/common_voice_13_0" \
	--dataset_config_name="hi" \
	--language="hindi" \
	--train_split_name="train+validation" \
	--eval_split_name="test" \
	--output_dir="./whisper-small-hi-flax" \
	--per_device_train_batch_size="16" \
	--per_device_eval_batch_size="16" \
	--num_train_epochs="10" \
	--learning_rate="1e-4" \
	--warmup_steps="500" \
	--logging_steps="25" \
	--generation_max_length="40" \
	--preprocessing_num_workers="32" \
	--dataloader_num_workers="32" \
	--max_duration_in_seconds="30" \
	--text_column_name="sentence" \
	--overwrite_output_dir \
	--do_train \
	--do_eval \
	--predict_with_generate \
	--push_to_hub \
	--use_auth_token
```
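Among the hyperparameters above, `--learning_rate="1e-4"` and `--warmup_steps="500"` define the learning-rate schedule. As a rough sketch of the common shape (linear warmup to the peak rate, then linear decay to zero) — the script builds its actual schedule with optax, so the decay details and `total_steps` value here are illustrative assumptions:

```python
def linear_warmup_decay(step, peak_lr=1e-4, warmup_steps=500, total_steps=5000):
    """Linear warmup to peak_lr over warmup_steps, then linear decay to zero.

    Plain-Python sketch of the schedule shape implied by --learning_rate
    and --warmup_steps; total_steps is a hypothetical training length.
    """
    if step < warmup_steps:
        # ramp up proportionally during warmup
        return peak_lr * step / warmup_steps
    # decay linearly from peak_lr at warmup_steps down to 0 at total_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, 1.0 - progress)
```

Warming up from zero avoids large, destabilizing updates while the optimizer statistics are still cold.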

On a TPU v4-8, training should take approximately 25 minutes, with a final cross-entropy loss of 0.02 and word error
rate of **34%**. See the checkpoint [sanchit-gandhi/whisper-small-hi-flax](https://huggingface.co/sanchit-gandhi/whisper-small-hi-flax)
for an example training run.
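Word error rate (WER) is the ratio of word-level edit operations (substitutions, insertions, deletions) between hypothesis and reference to the number of reference words. A minimal pure-Python sketch of the metric (the training script is assumed to compute WER via a metrics library; this standalone version shows what the number means):

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (substitutions + insertions + deletions) / reference word count,
    computed with a word-level Levenshtein distance."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,          # deletion
                           dp[i][j - 1] + 1,          # insertion
                           dp[i - 1][j - 1] + cost)   # substitution / match
    return dp[len(ref)][len(hyp)] / len(ref)
```

A WER of 34% therefore means roughly one edit per three reference words; note that WER can exceed 100% when the hypothesis contains many insertions.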