 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Testing suite for the PyTorch Wav2Vec2-Conformer model. """
-
 import math
+import tempfile
 import unittest

 import numpy as np
 from datasets import load_dataset

 from transformers import Wav2Vec2ConformerConfig, is_torch_available
-from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
+from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
@@ -215,6 +215,23 @@ def create_and_check_model_with_adapter_proj_dim(self, config, input_values, att
             (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
         )

+    def create_and_check_model_float16(self, config, input_values, attention_mask):
+        model = Wav2Vec2ConformerModel(config=config)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            model = Wav2Vec2ConformerModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)
+
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask)
+
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+        )
+
     def create_and_check_batch_inference(self, config, input_values, *args):
         # test does not pass for models making use of `group_norm`
         # check: https://github.com/pytorch/fairseq/issues/3227
@@ -451,6 +468,16 @@ def test_model_with_adapter_proj_dim(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)

+    @require_torch_gpu
+    def test_model_float16_with_relative(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
+        self.model_tester.create_and_check_model_float16(*config_and_inputs)
+
+    @require_torch_gpu
+    def test_model_float16_with_rotary(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
+        self.model_tester.create_and_check_model_float16(*config_and_inputs)
+
     def test_ctc_loss_inference(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_ctc_loss(*config_and_inputs)
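
For context, here is a minimal standalone sketch of the half-precision round trip the new tests exercise: save a randomly initialized Wav2Vec2ConformerModel, reload it with `torch_dtype=torch.float16`, and run a forward pass on half-precision inputs. The CUDA device, the dummy audio tensor, and the explicit `position_embeddings_type` choice are assumptions for illustration, not part of the diff.

```python
# Sketch only (not part of the diff above): the fp16 load path that
# create_and_check_model_float16 verifies, assuming a CUDA device is available
# (the new tests are gated with @require_torch_gpu).
import tempfile

import torch

from transformers import Wav2Vec2ConformerConfig, Wav2Vec2ConformerModel

config = Wav2Vec2ConformerConfig(position_embeddings_type="rotary")  # or "relative"
model = Wav2Vec2ConformerModel(config)

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    # Reload the weights directly in float16, mirroring the new test helper.
    model = Wav2Vec2ConformerModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)

model.to("cuda").eval()

# Dummy batch of raw 16 kHz audio, cast to float16 to match the model weights.
input_values = torch.randn(1, 16000, dtype=torch.float16, device="cuda")

with torch.no_grad():
    last_hidden_state = model(input_values).last_hidden_state

print(last_hidden_state.shape, last_hidden_state.dtype)  # expected dtype: torch.float16
```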