diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 61f4e88f22fb..dce699136535 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import multiprocessing as mp
 import random
 import warnings
 from collections.abc import Mapping
@@ -787,6 +788,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
             If set, will pad the sequence to a multiple of the provided value.
         return_tensors (`str`):
             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+        seed (`int`, *optional*):
+            The seed to use for the random number generator used for masking. If not provided, the global RNG will be used.
@@ -827,6 +830,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     pad_to_multiple_of: Optional[int] = None
     tf_experimental_compile: bool = False
     return_tensors: str = "pt"
+    seed: Optional[int] = None
 
     def __post_init__(self):
         if self.mlm and self.tokenizer.mask_token is None:
@@ -852,12 +856,57 @@ def __post_init__(self):
                 self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
 
+        self.generator = None
+
+    def get_generator(self, seed):
+        if self.return_tensors == "pt":
+            import torch
+
+            return torch.Generator().manual_seed(seed)
+        elif self.return_tensors == "tf":
+            import tensorflow as tf
+
+            return tf.random.Generator.from_seed(seed)
+        else:
+            import numpy as np
+
+            return np.random.default_rng(seed)
+
+    def create_rng(self):
+        if mp.current_process().name == "MainProcess":
+            # If we are in the main process, we create a generator object with the seed
+            self.generator = self.get_generator(self.seed)
+        else:
+            # If we are in a worker process (i.e. when using multiprocessing), we need to set a unique seed for
+            # each worker's generator, derived as the main seed + the worker's ID.
+            # (https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading)
+            # Only the PyTorch DataLoader lets us access the worker ID, so we check for it here.
+            # For other frameworks, we raise an error.
+            import torch
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is None:
+                error_string = (
+                    "Worker process information is not available for seeding the generator. This may be because "
+                    "you are using multiprocessing without a PyTorch DataLoader. The `seed` parameter can "
+                    "only be used for multiprocessing with a PyTorch DataLoader. Please either use a "
+                    "single process or a PyTorch DataLoader with multiple workers."
+                )
+                raise ValueError(error_string)
+
+            self.generator = self.get_generator(self.seed + worker_info.id)
+
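Since the worker-ID offset is the subtle part of `create_rng`, here is a minimal sketch of the derivation it performs (the `effective_seed` helper is hypothetical, purely for illustration): in the main process the generator is seeded with `seed` directly, while inside a PyTorch DataLoader worker the seed is offset by the worker's id so that workers do not all produce identical masks.

```python
# Hypothetical illustration, not part of the patch: how the effective seed is
# derived. `effective_seed` is an illustrative name only.
import torch

def effective_seed(base_seed: int) -> int:
    worker_info = torch.utils.data.get_worker_info()
    # None in the main process; inside a DataLoader worker it carries a 0-based `id`
    return base_seed if worker_info is None else base_seed + worker_info.id
```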
     @staticmethod
-    def tf_bernoulli(shape, probability):
+    def tf_bernoulli(shape, probability, generator=None):
         import tensorflow as tf
 
         prob_matrix = tf.fill(shape, probability)
-        return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
+        # if a generator exists, use it to generate the random numbers
+        # otherwise, use the global RNG
+        if generator:
+            return tf.cast(prob_matrix - generator.uniform(shape, 0, 1) >= 0, tf.bool)
+        else:
+            return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
 
     def tf_mask_tokens(
         self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
@@ -872,12 +921,12 @@ def tf_mask_tokens(
         input_shape = tf.shape(inputs)
         # 1 for a special token, 0 for a normal token in the special tokens mask
         # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
-        masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
+        masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability, self.generator) & ~special_tokens_mask
         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
         labels = tf.where(masked_indices, inputs, -100)
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices
+        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices
 
         inputs = tf.where(indices_replaced, mask_token_id, inputs)
 
@@ -891,9 +940,15 @@ def tf_mask_tokens(
         random_replace_prob_scaled = self.random_replace_prob / remaining_prob
         # random_replace_prob% of the time, we replace masked input tokens with random word
         indices_random = (
-            self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
+            self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
+            & masked_indices
+            & ~indices_replaced
         )
-        random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
+
+        if self.generator:
+            random_words = self.generator.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
+        else:
+            random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
 
         inputs = tf.where(indices_random, random_words, inputs)
 
@@ -903,6 +958,11 @@ def tf_mask_tokens(
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         import tensorflow as tf
 
+        if self.seed is not None and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed was supplied, we will use the global RNG
+            self.create_rng()
+
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
                 self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
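For the TensorFlow path, the patch threads an optional stateful `tf.random.Generator` through `tf_bernoulli`; two generators built from the same seed yield identical Bernoulli draws, which is what makes the masking reproducible. A minimal sketch, assuming TensorFlow is installed:

```python
# Minimal sketch: two generators built from the same seed produce the same mask.
import tensorflow as tf

def bernoulli_mask(gen, shape=(2, 4), probability=0.15):
    prob_matrix = tf.fill(shape, probability)
    return tf.cast(prob_matrix - gen.uniform(shape, 0, 1) >= 0, tf.bool)

mask_a = bernoulli_mask(tf.random.Generator.from_seed(42))
mask_b = bernoulli_mask(tf.random.Generator.from_seed(42))
assert bool(tf.reduce_all(mask_a == mask_b))
```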
@@ -943,6 +1003,12 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
 
+        if self.seed is not None and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed was supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
                 self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
@@ -983,11 +1049,14 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
             special_tokens_mask = special_tokens_mask.bool()
 
         probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
-        masked_indices = torch.bernoulli(probability_matrix).bool()
+        masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
+        indices_replaced = (
+            torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
+            & masked_indices
+        )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
 
         if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1001,11 +1070,11 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
 
         # random_replace_prob% of the time, we replace masked input tokens with random word
         indices_random = (
-            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
+            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
             & masked_indices
             & ~indices_replaced
         )
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
+        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
         inputs[indices_random] = random_words[indices_random]
 
         # The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
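The PyTorch path leans on the `generator=` keyword that `torch.bernoulli` and `torch.randint` already accept, so the draws are reproducible while the global RNG stays untouched. A minimal sketch (30522 merely stands in for `len(tokenizer)`):

```python
# Minimal sketch, not part of the patch: explicit generators make the masking
# draws repeatable without touching the global RNG.
import torch

def draw(seed):
    gen = torch.Generator().manual_seed(seed)
    probs = torch.full((2, 4), 0.15)
    masked = torch.bernoulli(probs, generator=gen).bool()
    words = torch.randint(30522, (2, 4), dtype=torch.long, generator=gen)
    return masked, words

# Same seed -> identical mask and identical random words on every call
assert all(torch.equal(a, b) for a, b in zip(draw(42), draw(42)))
```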
@@ -1013,6 +1082,12 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
 
+        if self.seed is not None and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed was supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
                 self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
@@ -1052,13 +1127,21 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
         probability_matrix[special_tokens_mask] = 0
         # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
-        masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
+        if self.generator:
+            masked_indices = self.generator.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
+        else:
+            masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = (
-            np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
-        )
+        if self.generator:
+            indices_replaced = (
+                self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
+        else:
+            indices_replaced = (
+                np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
         inputs[indices_replaced] = self.tokenizer.mask_token_id
 
         if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1069,14 +1152,24 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
         # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
         # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
         random_replace_prob_scaled = self.random_replace_prob / remaining_prob
-        indices_random = (
-            np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
-            & masked_indices
-            & ~indices_replaced
-        )
-        random_words = np.random.randint(
-            low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
-        )
+        if self.generator:
+            indices_random = (
+                self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = self.generator.integers(
+                low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
+            )
+        else:
+            indices_random = (
+                np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = np.random.randint(
+                low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
+            )
         inputs[indices_random] = random_words
 
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
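The NumPy path mirrors the others with `np.random.default_rng`, whose `Generator.binomial` and `Generator.integers` are the seeded counterparts of the legacy `np.random.binomial` / `np.random.randint` calls. A minimal sketch (again, 30522 stands in for `len(tokenizer)`):

```python
# Minimal sketch, not part of the patch: a seeded default_rng reproduces the
# same binomial mask and the same random word draws on every call.
import numpy as np

def draw(seed):
    rng = np.random.default_rng(seed)
    masked = rng.binomial(1, 0.15, size=(2, 4)).astype(bool)
    words = rng.integers(low=0, high=30522, size=np.count_nonzero(masked), dtype=np.int64)
    return masked, words

mask_a, words_a = draw(42)
mask_b, words_b = draw(42)
assert np.array_equal(mask_a, mask_b) and np.array_equal(words_a, words_b)
```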
diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py
index d631299c01f6..ca88b3c79c3e 100644
--- a/tests/trainer/test_data_collator.py
+++ b/tests/trainer/test_data_collator.py
@@ -350,6 +350,86 @@ def test_data_collator_for_language_modeling(self):
         pad_features = [list(range(5)), list(range(10))]
         self._test_no_pad_and_pad(no_pad_features, pad_features)
 
+    def test_data_collator_for_language_modeling_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if the seed is respected between two different DataCollatorForLanguageModeling instances
+        data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42)
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42)
+        batch_2 = data_collator(features)
+        self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))
+
+        self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
+        self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))
+
+        # check if the seed is respected when using multiple workers
+        features = [{"input_ids": list(range(1000))} for _ in range(10)]
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            generator=torch.Generator().manual_seed(42),
+            collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=42),
+        )
+
+        batch_3_input_ids = []
+        batch_3_labels = []
+        for batch in dataloader:
+            batch_3_input_ids.append(batch["input_ids"])
+            batch_3_labels.append(batch["labels"])
+
+        batch_3_input_ids = torch.stack(batch_3_input_ids)
+        batch_3_labels = torch.stack(batch_3_labels)
+        self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000)))
+
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=42),
+        )
+
+        batch_4_input_ids = []
+        batch_4_labels = []
+        for batch in dataloader:
+            batch_4_input_ids.append(batch["input_ids"])
+            batch_4_labels.append(batch["labels"])
+        batch_4_input_ids = torch.stack(batch_4_input_ids)
+        batch_4_labels = torch.stack(batch_4_labels)
+        self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000)))
+
+        self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
+        self.assertTrue(torch.all(batch_3_labels == batch_4_labels))
+
+        # try with a different seed
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=43),
+        )
+
+        batch_5_input_ids = []
+        batch_5_labels = []
+        for batch in dataloader:
+            batch_5_input_ids.append(batch["input_ids"])
+            batch_5_labels.append(batch["labels"])
+        batch_5_input_ids = torch.stack(batch_5_input_ids)
+        batch_5_labels = torch.stack(batch_5_labels)
+        self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000)))
+
+        self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
+        self.assertFalse(torch.all(batch_3_labels == batch_5_labels))
+
     def test_data_collator_for_whole_word_mask(self):
         tokenizer = BertTokenizer(self.vocab_file)
         data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
@@ -1077,6 +1157,33 @@ def test_data_collator_for_language_modeling(self):
         pad_features = [list(range(5)), list(range(10))]
         self._test_no_pad_and_pad(no_pad_features, pad_features)
 
+    def test_data_collator_for_language_modeling_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if the seed is respected between two different DataCollatorForLanguageModeling instances
+        data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="tf")
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000])
self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000]) + + data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="tf") + batch_2 = data_collator(features) + self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000]) + self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000]) + + self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"])) + self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"])) + + # try with different seed + data_collator = DataCollatorForLanguageModeling(tokenizer, seed=43, return_tensors="tf") + batch_3 = data_collator(features) + self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000]) + self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000]) + + self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"])) + self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"])) + def test_data_collator_for_whole_word_mask(self): tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf") @@ -1772,6 +1879,32 @@ def test_data_collator_for_language_modeling(self): pad_features = [list(range(5)), list(range(10))] self._test_no_pad_and_pad(no_pad_features, pad_features) + def test_data_collator_for_language_modeling_with_seed(self): + tokenizer = BertTokenizer(self.vocab_file) + features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}] + + # check if seed is respected between two different DataCollatorForLanguageModeling instances + data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="np") + batch_1 = data_collator(features) + self.assertEqual(batch_1["input_ids"].shape, (2, 1000)) + self.assertEqual(batch_1["labels"].shape, (2, 1000)) + + data_collator = DataCollatorForLanguageModeling(tokenizer, seed=42, return_tensors="np") + batch_2 = data_collator(features) + self.assertEqual(batch_2["input_ids"].shape, (2, 1000)) + self.assertEqual(batch_2["labels"].shape, (2, 1000)) + + self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"])) + self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"])) + + data_collator = DataCollatorForLanguageModeling(tokenizer, seed=43, return_tensors="np") + batch_3 = data_collator(features) + self.assertEqual(batch_3["input_ids"].shape, (2, 1000)) + self.assertEqual(batch_3["labels"].shape, (2, 1000)) + + self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"])) + self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"])) + def test_data_collator_for_whole_word_mask(self): tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")