diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
index 7bf7e0af0ad1..981f54aad48e 100644
--- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
+++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
@@ -303,14 +303,12 @@ def compute_loss(
             inputs,
             attention_mask,
             decoder_input_ids,
-            decoder_attention_mask,
             freeze_feature_encoder: bool = False,
         ):
             outputs_enc_dec = enc_dec_model(
                 inputs=inputs,
                 attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
                 freeze_feature_encoder=freeze_feature_encoder,
                 params=params,
             )
@@ -323,13 +321,11 @@ def compute_loss(
         grad_fn = jax.value_and_grad(compute_loss)

         # compute the loss and gradients for the unfrozen model
-        loss, grads = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
-        )
+        loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False)

         # compare to the loss and gradients for the frozen model
         loss_frozen, grads_frozen = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
+            params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True
         )

         self.assert_almost_equals(loss, loss_frozen, 1e-5)
@@ -348,14 +344,14 @@ def compute_loss(
             feature_extractor_grads, feature_extractor_grads_frozen
         ):
             self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
-            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10)

         # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
         grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
         grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
         for grad, grad_frozen in zip(grads, grads_frozen):
-            self.assert_almost_equals(grad, grad_frozen, 1e-8)
+            self.assert_almost_equals(grad, grad_frozen, 1e-10)

     def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
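
For context, the test pattern above computes the loss and gradients twice, once with the feature encoder frozen and once without, and asserts that freezing zeroes exactly the feature-extractor gradients while leaving the loss and all other gradients unchanged. Below is a minimal, self-contained sketch of that pattern (not code from this PR; all names and the toy model are illustrative), assuming the freezing is implemented with jax.lax.stop_gradient:

import jax
import jax.numpy as jnp

# Toy parameter pytree standing in for the real model's params
params = {
    "feature_extractor": jnp.ones((3,)),
    "encoder": jnp.ones((3,)),
}

def compute_loss(params, x, freeze_feature_encoder=False):
    feat_params = params["feature_extractor"]
    if freeze_feature_encoder:
        # stop_gradient blocks gradient flow into the frozen parameters
        feat_params = jax.lax.stop_gradient(feat_params)
    features = x * feat_params
    return jnp.sum(features * params["encoder"])

# differentiate with respect to the first argument (the params pytree)
grad_fn = jax.value_and_grad(compute_loss)
x = jnp.arange(3.0)

loss, grads = grad_fn(params, x, freeze_feature_encoder=False)
loss_frozen, grads_frozen = grad_fn(params, x, freeze_feature_encoder=True)

assert jnp.allclose(loss, loss_frozen)                          # loss unaffected by freezing
assert (grads_frozen["feature_extractor"] == 0.0).all()         # frozen grads are exactly zero
assert jnp.allclose(grads["encoder"], grads_frozen["encoder"])  # unfrozen grads unchanged

This also shows why the diff can drop decoder_attention_mask from compute_loss: jax.value_and_grad only differentiates the argnums it is given (here, params), so any forward-pass argument that does not affect the frozen/unfrozen comparison can simply be omitted from the closure's signature.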