diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 4af7d609f03d..1de846857124 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -843,6 +843,10 @@ def __post_init__(self): if self.random_replace_prob < 0 or self.random_replace_prob > 1: raise ValueError("random_replace_prob should be between 0 and 1.") + self.mlm_probability = float(self.mlm_probability) + self.mask_replace_prob = float(self.mask_replace_prob) + self.random_replace_prob = float(self.random_replace_prob) + if self.tf_experimental_compile: import tensorflow as tf diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py index c3e9b5a3badf..d631299c01f6 100644 --- a/tests/trainer/test_data_collator.py +++ b/tests/trainer/test_data_collator.py @@ -1052,7 +1052,9 @@ def test_all_mask_replacement(self): # confirm that every token is either the original token or [MASK] self.assertTrue( - tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)) + tf.reduce_all( + (batch["input_ids"] == tf.cast(inputs, tf.int64)) | (batch["input_ids"] == tokenizer.mask_token_id) + ) ) # numpy call