From 1c1f4acbae214b09bdaa9d38933c9e7e3ac678e8 Mon Sep 17 00:00:00 2001 From: Sumit Vij Date: Fri, 24 Jan 2025 06:51:03 +0000 Subject: [PATCH 1/3] Fix device mismatch error in whisper feature extraction --- .../models/whisper/feature_extraction_whisper.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 3c4d413d88e6..0e1864ba51e9 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -131,15 +131,14 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") """ waveform = torch.from_numpy(waveform).type(torch.float32) - window = torch.hann_window(self.n_fft) - if device != "cpu": + window = torch.hann_window(self.n_fft, device=device) + if device != waveform.device: waveform = waveform.to(device) - window = window.to(device) stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) magnitudes = stft[..., :-1].abs() ** 2 mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32) - if device != "cpu": + if device != mel_filters.device: mel_filters = mel_filters.to(device) mel_spec = mel_filters.T @ magnitudes From 2f2a914399bedcd913e0be1f0a1bf216089236de Mon Sep 17 00:00:00 2001 From: Sumit Vij Date: Mon, 27 Jan 2025 07:42:12 +0000 Subject: [PATCH 2/3] Set default device --- tests/models/whisper/test_feature_extraction_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index 4b2353bce002..cefdd8f77899 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -298,6 +298,7 @@ def test_torch_integration_batch(self): ) # fmt: on + torch.set_default_device("cuda") input_speech = self._load_datasamples(3) feature_extractor = WhisperFeatureExtractor() input_features = feature_extractor(input_speech, return_tensors="pt").input_features From d133b3f20def1fe3fae35805e43578e015a724dc Mon Sep 17 00:00:00 2001 From: Sumit Vij Date: Tue, 28 Jan 2025 04:59:06 +0000 Subject: [PATCH 3/3] Address code review feedback --- .../models/whisper/feature_extraction_whisper.py | 10 +++------- .../models/whisper/test_feature_extraction_whisper.py | 8 ++++---- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 0e1864ba51e9..1519fb028623 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -129,17 +129,13 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching, yielding results similar to cpu computing with 1e-5 tolerance. """ - waveform = torch.from_numpy(waveform).type(torch.float32) - + waveform = torch.from_numpy(waveform).to(device, torch.float32) window = torch.hann_window(self.n_fft, device=device) - if device != waveform.device: - waveform = waveform.to(device) + stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) magnitudes = stft[..., :-1].abs() ** 2 - mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32) - if device != mel_filters.device: - mel_filters = mel_filters.to(device) + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) mel_spec = mel_filters.T @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index cefdd8f77899..0c0c3797ab55 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -298,9 +298,9 @@ def test_torch_integration_batch(self): ) # fmt: on - torch.set_default_device("cuda") - input_speech = self._load_datasamples(3) - feature_extractor = WhisperFeatureExtractor() - input_features = feature_extractor(input_speech, return_tensors="pt").input_features + with torch.device("cuda"): + input_speech = self._load_datasamples(3) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="pt").input_features self.assertEqual(input_features.shape, (3, 80, 3000)) self.assertTrue(torch.allclose(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))