Commit 1da84ae

Fix Bug in Flax-Speech-Encoder-Decoder Test (#16041)
* Fix Bug in Flax-Speech-Encoder-Decoder Test
* change thresholds for CPU precision
1 parent b2a1c99 commit 1da84ae


tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py

Lines changed: 4 additions & 8 deletions
@@ -303,14 +303,12 @@ def compute_loss(
             inputs,
             attention_mask,
             decoder_input_ids,
-            decoder_attention_mask,
             freeze_feature_encoder: bool = False,
         ):
             outputs_enc_dec = enc_dec_model(
                 inputs=inputs,
                 attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
                 freeze_feature_encoder=freeze_feature_encoder,
                 params=params,
             )
@@ -323,13 +321,11 @@ def compute_loss(
         grad_fn = jax.value_and_grad(compute_loss)

         # compute the loss and gradients for the unfrozen model
-        loss, grads = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
-        )
+        loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False)

         # compare to the loss and gradients for the frozen model
         loss_frozen, grads_frozen = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
+            params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True
         )

         self.assert_almost_equals(loss, loss_frozen, 1e-5)
@@ -348,14 +344,14 @@ def compute_loss(
             feature_extractor_grads, feature_extractor_grads_frozen
         ):
             self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
-            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10)

         # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
         grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
         grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)

         for grad, grad_frozen in zip(grads, grads_frozen):
-            self.assert_almost_equals(grad, grad_frozen, 1e-8)
+            self.assert_almost_equals(grad, grad_frozen, 1e-10)

     def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
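
For context, the gradient check this diff touches can be sketched in a few lines of self-contained JAX. The toy compute_loss, the parameter names, and the use of stop_gradient to emulate the feature-encoder freeze are illustrative assumptions, not the model code from the repository:

import jax
import jax.numpy as jnp

# hypothetical toy parameters standing in for the real model's parameter pytree
params = {
    "feature_extractor": jnp.ones((3,)),
    "encoder": jnp.ones((3,)),
}

def compute_loss(params, inputs, freeze_feature_encoder=False):
    features = inputs * params["feature_extractor"]
    if freeze_feature_encoder:
        # stop_gradient blocks gradients from flowing into the feature extractor
        features = jax.lax.stop_gradient(features)
    return jnp.sum(features * params["encoder"])

grad_fn = jax.value_and_grad(compute_loss)
inputs = jnp.arange(3.0)

loss, grads = grad_fn(params, inputs, freeze_feature_encoder=False)
loss_frozen, grads_frozen = grad_fn(params, inputs, freeze_feature_encoder=True)

# freezing changes only the gradient flow, not the forward loss
assert jnp.allclose(loss, loss_frozen)
# frozen feature-extractor gradients are exactly zero ...
assert (grads_frozen["feature_extractor"] == 0.0).all()
# ... while the gradients of the remaining parameters are unchanged
assert jnp.allclose(grads["encoder"], grads_frozen["encoder"])

The real test performs the same comparison with its assert_almost_equals and assert_difference helpers, using the tolerances shown in the diff above.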
