 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import multiprocessing as mp
 import random
 import warnings
 from collections.abc import Mapping
@@ -787,6 +788,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
             If set, will pad the sequence to a multiple of the provided value.
         return_tensors (`str`):
             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+        seed (`int`, *optional*):
+            The seed to use for the random number generator for masking. If not provided, the global RNG will be used.
 
     <Tip>
 
@@ -827,6 +830,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     pad_to_multiple_of: Optional[int] = None
     tf_experimental_compile: bool = False
     return_tensors: str = "pt"
+    seed: Optional[int] = None
 
     def __post_init__(self):
         if self.mlm and self.tokenizer.mask_token is None:
@@ -852,12 +856,57 @@ def __post_init__(self):
 
             self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
 
+        self.generator = None
+
+    def get_generator(self, seed):
+        if self.return_tensors == "pt":
+            import torch
+
+            return torch.Generator().manual_seed(seed)
+        elif self.return_tensors == "tf":
+            import tensorflow as tf
+
+            return tf.random.Generator.from_seed(seed)
+        else:
+            import numpy as np
+
+            return np.random.default_rng(seed)
+
+    def create_rng(self):
+        if mp.current_process().name == "MainProcess":
+            # If we are in the main process, we create a generator object seeded with `self.seed`.
+            self.generator = self.get_generator(self.seed)
+        else:
+            # If we are in a worker process (i.e. when using multiprocessing), each worker's generator
+            # needs a unique seed, derived as the main seed plus the worker's ID.
+            # (https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading)
+            # Only the PyTorch DataLoader exposes the worker ID, so we check for it here and
+            # raise an error for other frameworks.
+            import torch
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is None:
+                raise ValueError(
+                    "Worker process information is not available for seeding the generator. This may be "
+                    "because you are using multiprocessing without a PyTorch DataLoader. The `seed` "
+                    "parameter can only be used with multiprocessing via a PyTorch DataLoader. Please "
+                    "either use a single process or a PyTorch DataLoader with multiple workers."
+                )
+
+            self.generator = self.get_generator(self.seed + worker_info.id)
+
     @staticmethod
-    def tf_bernoulli(shape, probability):
+    def tf_bernoulli(shape, probability, generator=None):
         import tensorflow as tf
 
         prob_matrix = tf.fill(shape, probability)
+        # if a generator is provided, use it to generate the random numbers;
+        # otherwise, use the global RNG
+        if generator:
+            return tf.cast(prob_matrix - generator.uniform(shape, 0, 1) >= 0, tf.bool)
+        else:
+            return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
-        return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
 
     def tf_mask_tokens(
         self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
@@ -872,12 +921,12 @@ def tf_mask_tokens(
         input_shape = tf.shape(inputs)
         # 1 for a special token, 0 for a normal token in the special tokens mask
         # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
-        masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
+        masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability, self.generator) & ~special_tokens_mask
         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
         labels = tf.where(masked_indices, inputs, -100)
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices
+        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices
 
         inputs = tf.where(indices_replaced, mask_token_id, inputs)
 
@@ -891,9 +940,15 @@ def tf_mask_tokens(
         random_replace_prob_scaled = self.random_replace_prob / remaining_prob
         # random_replace_prob% of the time, we replace masked input tokens with random word
         indices_random = (
-            self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
+            self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
+            & masked_indices
+            & ~indices_replaced
         )
-        random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
+
+        if self.generator:
+            random_words = self.generator.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
+        else:
+            random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
 
         inputs = tf.where(indices_random, random_words, inputs)
 
@@ -903,6 +958,11 @@ def tf_mask_tokens(
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         import tensorflow as tf
 
+        if self.seed is not None and self.generator is None:
+            # If a seed was supplied, create the generator once; subsequent calls reuse it.
+            # If no seed was supplied, the global RNG is used instead.
+            self.create_rng()
+
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
@@ -943,6 +1003,12 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict
 
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
+
+        if self.seed is not None and self.generator is None:
+            # If a seed was supplied, create the generator once; subsequent calls reuse it.
+            # If no seed was supplied, the global RNG is used instead.
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
                 self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
@@ -983,11 +1049,14 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
             special_tokens_mask = special_tokens_mask.bool()
 
         probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
-        masked_indices = torch.bernoulli(probability_matrix).bool()
+        masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
+        indices_replaced = (
+            torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
+            & masked_indices
+        )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
 
         if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1001,18 +1070,24 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
 
         # random_replace_prob% of the time, we replace masked input tokens with random word
         indices_random = (
-            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
+            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
             & masked_indices
             & ~indices_replaced
         )
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
+        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
         inputs[indices_random] = random_words[indices_random]
 
         # The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels
 
     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
+
+        if self.seed is not None and self.generator is None:
+            # If a seed was supplied, create the generator once; subsequent calls reuse it.
+            # If no seed was supplied, the global RNG is used instead.
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             batch = pad_without_fast_tokenizer_warning(
                 self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
@@ -1052,13 +1127,21 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
 
         probability_matrix[special_tokens_mask] = 0
         # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
-        masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
+        if self.generator:
+            masked_indices = self.generator.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
+        else:
+            masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
         # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = (
-            np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
-        )
+        if self.generator:
+            indices_replaced = (
+                self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
+        else:
+            indices_replaced = (
+                np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
         inputs[indices_replaced] = self.tokenizer.mask_token_id
 
         if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1069,14 +1152,24 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
         # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
         # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
         random_replace_prob_scaled = self.random_replace_prob / remaining_prob
-        indices_random = (
-            np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
-            & masked_indices
-            & ~indices_replaced
-        )
-        random_words = np.random.randint(
-            low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
-        )
+        if self.generator:
+            indices_random = (
+                self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = self.generator.integers(
+                low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
+            )
+        else:
+            indices_random = (
+                np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = np.random.randint(
+                low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
+            )
         inputs[indices_random] = random_words
 
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
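
For reference, a minimal usage sketch of the new `seed` parameter (the checkpoint name and example sentences are illustrative placeholders, not part of this change):

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # With a fixed seed, the masking pattern is reproducible across runs;
    # without one, the collator keeps using the global RNG as before.
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=0.15,
        seed=42,
        return_tensors="pt",
    )

    examples = [tokenizer("The first sentence."), tokenizer("The second sentence.")]
    batch = collator(examples)  # same inputs + same seed -> same masks every run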
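And a sketch of the multi-worker path that create_rng handles, assuming `dataset` is a map-style PyTorch dataset of tokenized examples and `collator` is the instance from the sketch above:

    from torch.utils.data import DataLoader

    # Each DataLoader worker runs in its own process, so create_rng() seeds that
    # worker's generator with seed + worker_id: workers draw different masks from
    # one another, while the run as a whole remains reproducible.
    loader = DataLoader(dataset, batch_size=16, num_workers=2, collate_fn=collator)

    for batch in loader:
        ...  # batch["input_ids"] / batch["labels"] with deterministic masking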