@@ -303,14 +303,12 @@ def compute_loss(
             inputs,
             attention_mask,
             decoder_input_ids,
-            decoder_attention_mask,
             freeze_feature_encoder: bool = False,
         ):
             outputs_enc_dec = enc_dec_model(
                 inputs=inputs,
                 attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
                 freeze_feature_encoder=freeze_feature_encoder,
                 params=params,
             )
@@ -323,13 +321,11 @@ def compute_loss(
         grad_fn = jax.value_and_grad(compute_loss)

         # compute the loss and gradients for the unfrozen model
-        loss, grads = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
-        )
+        loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False)

         # compare to the loss and gradients for the frozen model
         loss_frozen, grads_frozen = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
+            params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True
         )

         self.assert_almost_equals(loss, loss_frozen, 1e-5)
@@ -348,14 +344,14 @@ def compute_loss(
             feature_extractor_grads, feature_extractor_grads_frozen
         ):
             self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
-            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10)
Contributor
No need to make it that aggressive btw :-) We usually are happy with 1e-4

Contributor Author
Sure! I just set it in tandem with the assert_almost_equals threshold that is run for the unfrozen gradients.


         # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
         grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
         grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)

         for grad, grad_frozen in zip(grads, grads_frozen):
-            self.assert_almost_equals(grad, grad_frozen, 1e-8)
+            self.assert_almost_equals(grad, grad_frozen, 1e-10)

def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

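For context on the tolerances discussed in the review thread above: assert_almost_equals bounds the largest element-wise difference from above, while assert_difference (the helper this PR uses to confirm the gradients genuinely change, here the unfrozen feature-extractor gradients versus the all-zero frozen ones) bounds it from below. The sketch below is only a guess at how such helpers are typically written, not code copied from the PR; the class name, signatures, and messages are assumptions.

import unittest

import numpy as np


class GradientAssertions(unittest.TestCase):
    # Hypothetical helpers for illustration only; the real ones live in the
    # transformers test mixin and may differ in naming and wording.

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        # upper bound: the largest element-wise difference must not exceed `tol`
        diff = np.abs(a - b).max()
        self.assertLessEqual(diff, tol, f"Difference between a and b is {diff} (>= {tol}).")

    def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
        # lower bound: the arrays must differ by at least `tol` somewhere
        diff = np.abs(a - b).max()
        self.assertGreaterEqual(diff, tol, f"Difference between a and b is {diff} (< {tol}).")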
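Stepping back from the diff, the test follows a standard JAX recipe: take jax.value_and_grad of a loss whose feature-extractor output is routed through jax.lax.stop_gradient when freeze_feature_encoder=True, then check that the loss is unchanged, the frozen branch's gradients are exactly zero, and every other gradient matches the unfrozen run. Below is a self-contained toy sketch of that recipe; it is not the transformers model or test, and all parameter names are stand-ins.

import jax
import jax.numpy as jnp

# Toy parameters standing in for the model's sub-modules.
params = {
    "feature_extractor": jnp.ones((4,)),
    "encoder": jnp.full((4,), 2.0),
}
inputs = jnp.arange(4.0)


def compute_loss(params, inputs, freeze_feature_encoder: bool = False):
    features = params["feature_extractor"] * inputs
    if freeze_feature_encoder:
        # cut the gradient path through the "feature extractor"
        # without changing the forward value
        features = jax.lax.stop_gradient(features)
    hidden = params["encoder"] * features
    return jnp.sum(hidden**2)


grad_fn = jax.value_and_grad(compute_loss)

loss, grads = grad_fn(params, inputs, freeze_feature_encoder=False)
loss_frozen, grads_frozen = grad_fn(params, inputs, freeze_feature_encoder=True)

# identical loss, zero gradients for the frozen branch, identical gradients elsewhere
assert jnp.allclose(loss, loss_frozen)
assert (grads_frozen["feature_extractor"] == 0.0).all()
assert jnp.allclose(grads["encoder"], grads_frozen["encoder"])

Because stop_gradient leaves the forward pass untouched, the loss values compare equal to high precision while only the feature-extractor gradients collapse to zero, which is exactly the split the assertions in the hunks above exercise.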