@@ -303,14 +303,12 @@ def compute_loss(
             inputs,
             attention_mask,
             decoder_input_ids,
-            decoder_attention_mask,
             freeze_feature_encoder: bool = False,
         ):
             outputs_enc_dec = enc_dec_model(
                 inputs=inputs,
                 attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
                 freeze_feature_encoder=freeze_feature_encoder,
                 params=params,
             )
@@ -323,13 +321,11 @@ def compute_loss(
         grad_fn = jax.value_and_grad(compute_loss)
 
         # compute the loss and gradients for the unfrozen model
-        loss, grads = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
-        )
+        loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False)
 
         # compare to the loss and gradients for the frozen model
         loss_frozen, grads_frozen = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
+            params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True
         )
 
         self.assert_almost_equals(loss, loss_frozen, 1e-5)
@@ -348,14 +344,14 @@ def compute_loss(
             feature_extractor_grads, feature_extractor_grads_frozen
         ):
             self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
-            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10)
 
         # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
         grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
         grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
 
         for grad, grad_frozen in zip(grads, grads_frozen):
-            self.assert_almost_equals(grad, grad_frozen, 1e-8)
+            self.assert_almost_equals(grad, grad_frozen, 1e-10)
 
     def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
 
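Aside (not part of the diff): the mechanism this test exercises, where freeze_feature_encoder zeroes the feature extractor's gradients while leaving the loss value untouched, can be sketched in plain JAX with jax.lax.stop_gradient. The toy model, parameter names, and shapes below are illustrative assumptions, not the Flax speech encoder-decoder under test:

import jax
import jax.numpy as jnp

def loss_fn(params, x, freeze_feature_encoder=False):
    # "feature extractor" stage
    features = x @ params["feature_extractor"]
    if freeze_feature_encoder:
        # detach: the forward value is unchanged, but gradient flow is cut
        features = jax.lax.stop_gradient(features)
    # "head" stage on top of the (possibly detached) features
    logits = features @ params["head"]
    return jnp.sum(logits ** 2)

params = {
    "feature_extractor": jnp.ones((4, 8)),
    "head": jnp.ones((8, 2)),
}
x = jnp.ones((3, 4))

grad_fn = jax.value_and_grad(loss_fn)

loss, grads = grad_fn(params, x, freeze_feature_encoder=False)
loss_frozen, grads_frozen = grad_fn(params, x, freeze_feature_encoder=True)

assert jnp.allclose(loss, loss_frozen)                   # forward pass is identical
assert (grads_frozen["feature_extractor"] == 0.0).all()  # frozen grads are exactly zero
assert jnp.allclose(grads["head"], grads_frozen["head"]) # unfrozen grads are unchanged

Because only gradient flow is cut, not the forward computation, the frozen and unfrozen losses agree, which is what the test's assert_almost_equals(loss, loss_frozen, 1e-5) checks, while the feature extractor's gradients collapse to zero exactly, matching the assertTrue((feature_extractor_grad_frozen == 0.0).all()) assertion.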