
Commit ef9a03b

capemox authored and Rocketknight1 committed
Add support for seed in DataCollatorForLanguageModeling. Also wrote tests for verifying behaviour.
1 parent ecd60d0 commit ef9a03b
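For context, a minimal usage sketch of the new parameter (illustrative only and not part of this commit; the checkpoint name and example sentences are placeholders): passing `seed` gives the collator its own RNG, so masking no longer depends on the global random state.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42)

# With a fixed seed, a freshly constructed collator produces the same mask positions on every run,
# independent of the global RNG state.
batch = collator([tokenizer("Hello world"), tokenizer("Seeding makes masking reproducible")])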

2 files changed: 248 additions, 22 deletions

src/transformers/data/data_collator.py

Lines changed: 115 additions & 22 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import multiprocessing as mp
 import random
 import warnings
 from collections.abc import Mapping
@@ -787,6 +788,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
             If set, will pad the sequence to a multiple of the provided value.
         return_tensors (`str`):
             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+        seed (`int`, *optional*):
+            The seed to use for the random number generator for masking. If not provided, the global RNG will be used.
 
     <Tip>
 
@@ -827,6 +830,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     pad_to_multiple_of: Optional[int] = None
     tf_experimental_compile: bool = False
     return_tensors: str = "pt"
+    seed: Optional[int] = None
 
     def __post_init__(self):
         if self.mlm and self.tokenizer.mask_token is None:
@@ -852,12 +856,57 @@ def __post_init__(self):
 
             self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
 
+        self.generator = None
+
+    def get_generator(self, seed):
+        if self.return_tensors == "pt":
+            import torch
+
+            return torch.Generator().manual_seed(seed)
+        elif self.return_tensors == "tf":
+            import tensorflow as tf
+
+            return tf.random.Generator.from_seed(seed)
+        else:
+            import numpy as np
+
+            return np.random.default_rng(seed)
+
+    def create_rng(self):
+        if mp.current_process().name == "MainProcess":
+            # If we are in the main process, we create a generator object seeded with `self.seed`.
+            self.generator = self.get_generator(self.seed)
+        else:
+            # If we are in a worker process (i.e. using multiprocessing), we need a unique seed for each
+            # worker's generator, computed as the main seed + the worker's ID.
+            # (https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading)
+            # Only the PyTorch DataLoader exposes the worker ID, so we check for it here and raise an
+            # error for other frameworks.
+            import torch
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is None:
+                error_string = (
+                    "Worker process information is not available for seeding the generator. This may be because "
+                    "you are using multiprocessing without using a PyTorch DataLoader. The `seed` parameter can "
+                    "only be used when using multiprocessing with a PyTorch DataLoader. Please either use a "
+                    "single process or use a PyTorch DataLoader with multiple workers."
+                )
+                raise ValueError(error_string)
+
+            self.generator = self.get_generator(self.seed + worker_info.id)
+
     @staticmethod
-    def tf_bernoulli(shape, probability):
+    def tf_bernoulli(shape, probability, generator=None):
         import tensorflow as tf
 
         prob_matrix = tf.fill(shape, probability)
-        return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
+        # If a generator exists, use it to generate the random numbers;
+        # otherwise, use the global RNG.
+        if generator:
+            return tf.cast(prob_matrix - generator.uniform(shape, 0, 1) >= 0, tf.bool)
+        else:
+            return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
 
     def tf_mask_tokens(
         self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
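A minimal sketch of how the per-worker seeding above plays out in practice (illustrative, assuming a `tokenizer` and a tokenized `dataset` already exist; both are placeholders): each DataLoader worker invokes the collator in its own process, where `torch.utils.data.get_worker_info()` exposes the worker ID that `create_rng` adds to the base seed.

from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, seed=42)  # tokenizer is a placeholder
loader = DataLoader(dataset, batch_size=8, num_workers=4, collate_fn=collator)  # dataset is a placeholder

for batch in loader:
    # Worker i seeds its generator with 42 + i, per the create_rng logic above.
    pass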
@@ -872,12 +921,12 @@ def tf_mask_tokens(
         input_shape = tf.shape(inputs)
         # 1 for a special token, 0 for a normal token in the special tokens mask
         # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
-        masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
+        masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability, self.generator) & ~special_tokens_mask
         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
         labels = tf.where(masked_indices, inputs, -100)
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices
+        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices
 
         inputs = tf.where(indices_replaced, mask_token_id, inputs)
 
@@ -891,9 +940,15 @@ def tf_mask_tokens(
         random_replace_prob_scaled = self.random_replace_prob / remaining_prob
         # random_replace_prob% of the time, we replace masked input tokens with random word
         indices_random = (
-            self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
+            self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
+            & masked_indices
+            & ~indices_replaced
         )
-        random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
+
+        if self.generator:
+            random_words = self.generator.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
+        else:
+            random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
 
         inputs = tf.where(indices_random, random_words, inputs)
 
@@ -903,6 +958,11 @@
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         import tensorflow as tf
 
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object; subsequent calls reuse the same generator.
+            # If no seed is supplied, we use the global RNG.
+            self.create_rng()
+
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
@@ -943,6 +1003,12 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict
 
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
+
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object; subsequent calls reuse the same generator.
+            # If no seed is supplied, we use the global RNG.
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
                 self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
@@ -983,11 +1049,14 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
             special_tokens_mask = special_tokens_mask.bool()
 
         probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
-        masked_indices = torch.bernoulli(probability_matrix).bool()
+        masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
+        indices_replaced = (
+            torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
+            & masked_indices
+        )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
 
         if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1001,18 +1070,24 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
 
         # random_replace_prob% of the time, we replace masked input tokens with random word
         indices_random = (
-            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
+            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
             & masked_indices
             & ~indices_replaced
         )
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
+        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
         inputs[indices_random] = random_words[indices_random]
 
         # The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels
 
     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
+
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object; subsequent calls reuse the same generator.
+            # If no seed is supplied, we use the global RNG.
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
                 self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
@@ -1052,13 +1127,21 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
 
         probability_matrix[special_tokens_mask] = 0
         # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
-        masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
+        if self.generator:
+            masked_indices = self.generator.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
+        else:
+            masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = (
-            np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
-        )
+        if self.generator:
+            indices_replaced = (
+                self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
+        else:
+            indices_replaced = (
+                np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
         inputs[indices_replaced] = self.tokenizer.mask_token_id
 
         if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1069,14 +1152,24 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
         # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
         # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
         random_replace_prob_scaled = self.random_replace_prob / remaining_prob
-        indices_random = (
-            np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
-            & masked_indices
-            & ~indices_replaced
-        )
-        random_words = np.random.randint(
-            low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
-        )
+        if self.generator:
+            indices_random = (
+                self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = self.generator.integers(
+                low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
+            )
+        else:
+            indices_random = (
+                np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = np.random.randint(
+                low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
+            )
         inputs[indices_random] = random_words
 
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
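The commit message mentions tests for the new behaviour; the test file is not shown in this excerpt, but a hedged sketch of the kind of property such tests would verify (identical masking for identical seeds, regardless of the global RNG state; checkpoint name and sentences are placeholders) might look like this:

import torch
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
examples = [tokenizer("The quick brown fox"), tokenizer("jumps over the lazy dog")]

c1 = DataCollatorForLanguageModeling(tokenizer=tokenizer, seed=123)
c2 = DataCollatorForLanguageModeling(tokenizer=tokenizer, seed=123)

torch.manual_seed(0)    # perturbing the global RNG should not matter
batch1 = c1(examples)
torch.manual_seed(999)
batch2 = c2(examples)

assert torch.equal(batch1["input_ids"], batch2["input_ids"])
assert torch.equal(batch1["labels"], batch2["labels"])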
