@@ -406,13 +406,15 @@ def forward(self, hidden_states):
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length
        # Embeddings are computed in the dtype of the inv_freq constant
Collaborator:
This now looks a lot like:

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

Wondering if we can add a # Copied from and reuse this / also wondering if the dynamic scaling could work for audio models?
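For context, the "dynamic scaling" mentioned here presumably refers to the dynamic NTK rotary scaling used for Llama, where inv_freq is recomputed with a larger base once the input outgrows max_position_embeddings. A minimal sketch of that idea (not the exact transformers implementation; the function name and defaults are illustrative):

import torch


def dynamic_ntk_inv_freq(dim, seq_len, max_position_embeddings=2048, base=10000.0, scaling_factor=1.0):
    """Sketch of dynamic NTK scaling: stretch the base for sequences longer than the
    training context so the rotary frequencies interpolate instead of extrapolating."""
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))


# Frequencies are unchanged up to 2048 positions; beyond that, the base grows with seq_len.
print(dynamic_ntk_inv_freq(dim=64, seq_len=1024)[:3])
print(dynamic_ntk_inv_freq(dim=64, seq_len=8192)[:3])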

Contributor Author:
Can't use # Copied from on the whole module since Wav2Vec2ConformerRotaryPositionalEmbedding accepts the config as an argument, whereas LlamaRotaryEmbedding uses various ad-hoc arguments. But we could do a similar dynamic slicing - will add this in a follow-up PR so as not to block @Vaibhavs10
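Not the follow-up implementation itself, but a rough sketch of what config-based dynamic slicing could look like, assuming the config exposes hidden_size, num_attention_heads, and rotary_embedding_base (the class name is hypothetical):

import torch
from torch import nn


class RotaryPositionalEmbeddingWithSlicingSketch(nn.Module):
    """Hypothetical config-based rotary embedding that grows its cache and slices it."""

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size // config.num_attention_heads
        base = config.rotary_embedding_base
        # inv_freq is kept in float32 so the angles are computed at full precision
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.cached_sequence_length = 0
        self.cached_rotary_positional_embedding = None

    def _build_cache(self, sequence_length):
        self.cached_sequence_length = sequence_length
        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)
        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        # Cache stays in float32; the forward pass casts the sliced view as needed
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]
        # Rebuild only when the incoming sequence is longer than anything cached so far,
        # mirroring LlamaRotaryEmbedding; shorter sequences just slice the existing cache.
        if self.cached_rotary_positional_embedding is None or sequence_length > self.cached_sequence_length:
            self._build_cache(sequence_length)
        return self.cached_rotary_positional_embedding[:, :sequence_length].type_as(hidden_states)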

        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
        # Computed embeddings are cast to the dtype of the hidden state inputs
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
        return self.cached_rotary_positional_embedding
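
The change keeps the cos/sin computation in the float32 dtype of inv_freq and only casts the cached result to the hidden-state dtype at the end. A self-contained sketch of that flow with made-up dimensions (shapes and base are illustrative, not taken from the config):

import torch

# Illustrative shapes: (batch, sequence_length, hidden_size) in half precision
hidden_states = torch.randn(1, 4096, 256, dtype=torch.float16)
head_dim = 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

# Positions and angles are built in float32: float16 only represents even integers
# above 2048, so half-precision position indices would drift for long sequences.
time_stamps = torch.arange(hidden_states.shape[1]).type_as(inv_freq)
freqs = torch.einsum("i,j->ij", time_stamps, inv_freq)
embeddings = torch.cat((freqs, freqs), dim=-1)

# Only the final cache is cast to the hidden-state dtype, so downstream ops
# receive a single dtype rather than a float32/float16 mix.
cos_embeddings = embeddings.cos()[:, None, None, :]
sin_embeddings = embeddings.sin()[:, None, None, :]
cache = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
print(cache.dtype, cache.shape)  # torch.float16 torch.Size([2, 4096, 1, 1, 64])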


@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """

import math
import tempfile
import unittest

import numpy as np
from datasets import load_dataset

from transformers import Wav2Vec2ConformerConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@@ -215,6 +215,23 @@ def create_and_check_model_with_adapter_proj_dim(self, config, input_values, att
            (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
        )

    def create_and_check_model_float16(self, config, input_values, attention_mask):
        model = Wav2Vec2ConformerModel(config=config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = Wav2Vec2ConformerModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)

        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask)

        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
        )

    def create_and_check_batch_inference(self, config, input_values, *args):
        # test does not pass for models making use of `group_norm`
        # check: https://github.com/pytorch/fairseq/issues/3227
@@ -451,6 +468,16 @@ def test_model_with_adapter_proj_dim(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)

    @require_torch_gpu
    def test_model_float16_with_relative(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
        self.model_tester.create_and_check_model_float16(*config_and_inputs)

    @require_torch_gpu
    def test_model_float16_with_rotary(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
        self.model_tester.create_and_check_model_float16(*config_and_inputs)

    def test_ctc_loss_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_ctc_loss(*config_and_inputs)
20 changes: 20 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -901,6 +901,26 @@ def test_speech_to_text_leveraged(self):
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": "a man said to the universe sir i exist"})

    @slow
    @require_torch_gpu
    def test_wav2vec2_conformer_float16(self):
Contributor Author:

This is the repro for the error that was failing before, @Vaibhavs10 - I've added a slow integration test to make sure it works after the fix

Member:

Perfect! Thanks <3

        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model="facebook/wav2vec2-conformer-rope-large-960h-ft",
            device="cuda:0",
            torch_dtype=torch.float16,
            framework="pt",
        )

        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        sample = dataset[0]["audio"]

        output = speech_recognizer(sample)
        self.assertEqual(
            output,
            {"text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"},
        )

    @require_torch
    def test_chunking_fast(self):
        speech_recognizer = pipeline(