Skip to content

Commit 8d51801

Browse files
[Wav2Vec2 Conformer] Fix inference float16 (#25985)
* [Wav2Vec2 Conformer] Fix inference float16
* fix test
* fix test more
* clean pipe test
1 parent 6bc517c commit 8d51801

File tree

3 files changed

+52
-3
lines changed

3 files changed

+52
-3
lines changed

src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -406,13 +406,15 @@ def forward(self, hidden_states):
406406
return self.cached_rotary_positional_embedding
407407

408408
self.cached_sequence_length = sequence_length
409+
# Embeddings are computed in the dtype of the inv_freq constant
409410
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
410411
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
411412
embeddings = torch.cat((freqs, freqs), dim=-1)
412413

413414
cos_embeddings = embeddings.cos()[:, None, None, :]
414415
sin_embeddings = embeddings.sin()[:, None, None, :]
415-
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
416+
# Computed embeddings are cast to the dtype of the hidden state inputs
417+
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
416418
return self.cached_rotary_positional_embedding
417419

418420

tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,15 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
16-
1716
import math
17+
import tempfile
1818
import unittest
1919

2020
import numpy as np
2121
from datasets import load_dataset
2222

2323
from transformers import Wav2Vec2ConformerConfig, is_torch_available
24-
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
24+
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device
2525

2626
from ...test_configuration_common import ConfigTester
2727
from ...test_modeling_common import (
@@ -215,6 +215,23 @@ def create_and_check_model_with_adapter_proj_dim(self, config, input_values, att
215215
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
216216
)
217217

218+
def create_and_check_model_float16(self, config, input_values, attention_mask):
    """Reload the model in float16 and check the shape of its last hidden state.

    The model is built from ``config``, saved to a temporary directory and
    loaded back with ``torch_dtype=torch.float16`` before being run on
    ``input_values`` cast to float16.
    """
    model = Wav2Vec2ConformerModel(config=config)

    # Round-trip through disk so that `from_pretrained` produces the float16 model.
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        model.save_pretrained(checkpoint_dir)
        model = Wav2Vec2ConformerModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float16)

    model.to(torch_device)
    model.eval()

    half_inputs = input_values.type(dtype=torch.float16)
    with torch.no_grad():
        result = model(half_inputs, attention_mask=attention_mask)

    expected_shape = (self.batch_size, self.output_seq_length, self.hidden_size)
    self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)
234+
218235
def create_and_check_batch_inference(self, config, input_values, *args):
219236
# test does not pass for models making use of `group_norm`
220237
# check: https://github.com/pytorch/fairseq/issues/3227
@@ -451,6 +468,16 @@ def test_model_with_adapter_proj_dim(self):
451468
config_and_inputs = self.model_tester.prepare_config_and_inputs()
452469
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
453470

471+
@require_torch_gpu
def test_model_float16_with_relative(self):
    """Run the float16 model check with `position_embeddings_type="relative"`."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs(position_embeddings_type="relative")
    tester.create_and_check_model_float16(*prepared)
475+
476+
@require_torch_gpu
def test_model_float16_with_rotary(self):
    """Run the float16 model check with `position_embeddings_type="rotary"`."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs(position_embeddings_type="rotary")
    tester.create_and_check_model_float16(*prepared)
480+
454481
def test_ctc_loss_inference(self):
455482
config_and_inputs = self.model_tester.prepare_config_and_inputs()
456483
self.model_tester.check_ctc_loss(*config_and_inputs)

tests/pipelines/test_pipelines_automatic_speech_recognition.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -901,6 +901,26 @@ def test_speech_to_text_leveraged(self):
901901
output = speech_recognizer(filename)
902902
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
903903

904+
@slow
@require_torch_gpu
def test_wav2vec2_conformer_float16(self):
    """Transcribe one dummy LibriSpeech sample with a float16 ASR pipeline.

    The pipeline loads `facebook/wav2vec2-conformer-rope-large-960h-ft` on
    `cuda:0` with `torch_dtype=torch.float16`, and the transcription of the
    first validation sample is compared against a fixed reference text.
    """
    pipeline_kwargs = {
        "task": "automatic-speech-recognition",
        "model": "facebook/wav2vec2-conformer-rope-large-960h-ft",
        "device": "cuda:0",
        "torch_dtype": torch.float16,
        "framework": "pt",
    }
    speech_recognizer = pipeline(**pipeline_kwargs)

    validation_split = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    first_audio = validation_split[0]["audio"]

    transcription = speech_recognizer(first_audio)
    expected = {
        "text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL",
    }
    self.assertEqual(transcription, expected)
923+
904924
@require_torch
905925
def test_chunking_fast(self):
906926
speech_recognizer = pipeline(

0 commit comments

Comments
 (0)