Skip to content

Commit 00366e0

Browse files
capemoxRocketknight1
authored andcommitted
Fixed 2 issues regarding tests/trainer/test_data_collator.py::TFDataCollatorIntegrationTest::test_all_mask_replacement:
1. I got the error `RuntimeError: "bernoulli_tensor_cpu_p_" not implemented for 'Long'`. This is because the `mask_replacement_prob=1` and `torch.bernoulli` doesn't accept this type (which would be a `torch.long` dtype instead. I fixed this by manually casting the probability arguments in the `__post_init__` function of `DataCollatorForLanguageModeling`. 2. I also got the error `tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute Equal as input #1(zero-based) was expected to be a int64 tensor but is a int32 tensor [Op:Equal]` due to the line `tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id))` in `test_data_collator.py`. This occurs because the type of the `inputs` variable is `tf.int32`. Solved this by manually casting it to `tf.int64` in the test, as the expected return type of `batch["input_ids"]` is `tf.int64`.
1 parent 02776d2 commit 00366e0

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/transformers/data/data_collator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,10 @@ def __post_init__(self):
843843
if self.random_replace_prob < 0 or self.random_replace_prob > 1:
844844
raise ValueError("random_replace_prob should be between 0 and 1.")
845845

846+
self.mlm_probability = float(self.mlm_probability)
847+
self.mask_replace_prob = float(self.mask_replace_prob)
848+
self.random_replace_prob = float(self.random_replace_prob)
849+
846850
if self.tf_experimental_compile:
847851
import tensorflow as tf
848852

tests/trainer/test_data_collator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,9 @@ def test_all_mask_replacement(self):
10521052

10531053
# confirm that every token is either the original token or [MASK]
10541054
self.assertTrue(
1055-
tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id))
1055+
tf.reduce_all(
1056+
(batch["input_ids"] == tf.cast(inputs, tf.int64)) | (batch["input_ids"] == tokenizer.mask_token_id)
1057+
)
10561058
)
10571059

10581060
# numpy call

0 commit comments

Comments
 (0)