137 changes: 115 additions & 22 deletions src/transformers/data/data_collator.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing as mp
import random
import warnings
from collections.abc import Mapping
@@ -787,6 +788,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
If set, will pad the sequence to a multiple of the provided value.
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
seed (`int`, *optional*):
The seed to use for the random number generator for masking. If not provided, the global RNG will be used.

<Tip>

@@ -827,6 +830,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
pad_to_multiple_of: Optional[int] = None
tf_experimental_compile: bool = False
return_tensors: str = "pt"
seed: Optional[int] = None

def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
@@ -852,12 +856,57 @@ def __post_init__(self):

self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)

self.generator = None

def get_generator(self, seed):
if self.return_tensors == "pt":
import torch

return torch.Generator().manual_seed(seed)
elif self.return_tensors == "tf":
import tensorflow as tf

return tf.random.Generator.from_seed(seed)
else:
import numpy as np

return np.random.default_rng(seed)

def create_rng(self):
if mp.current_process().name == "MainProcess":
# If we are in the main process, we create a generator object with the seed
self.generator = self.get_generator(self.seed)
else:
# If we are in a worker process (i.e. using multiprocessing), we need to set a unique seed for each
# worker's generator, generated as the main seed + the worker's ID.
# (https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading)
# Only PyTorch DataLoader allows us to access the worker ID, and so we check for this.
# For other frameworks, we will throw an error.
import torch

worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
error_string = (
"Worker process information is not available for seeding the generator. This may be because "
"you are using multiprocessing without a PyTorch DataLoader. The `seed` parameter can only "
"be used with multiprocessing through a PyTorch DataLoader. Please either use a single "
"process or a PyTorch DataLoader with multiple workers."
)
raise ValueError(error_string)

self.generator = self.get_generator(self.seed + worker_info.id)

@staticmethod
def tf_bernoulli(shape, probability):
def tf_bernoulli(shape, probability, generator=None):
import tensorflow as tf

prob_matrix = tf.fill(shape, probability)
return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
# if generator exists, use it to generate the random numbers
# otherwise, use the global RNG
if generator:
return tf.cast(prob_matrix - generator.uniform(shape, 0, 1) >= 0, tf.bool)
else:
return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)

def tf_mask_tokens(
self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
@@ -872,12 +921,12 @@ def tf_mask_tokens(
input_shape = tf.shape(inputs)
# 1 for a special token, 0 for a normal token in the special tokens mask
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability, self.generator) & ~special_tokens_mask
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
labels = tf.where(masked_indices, inputs, -100)

# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices

inputs = tf.where(indices_replaced, mask_token_id, inputs)

@@ -891,9 +940,15 @@ def tf_mask_tokens(
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
# random_replace_prob% of the time, we replace masked input tokens with random word
indices_random = (
self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
& masked_indices
& ~indices_replaced
)
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)

if self.generator:
random_words = self.generator.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
else:
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)

inputs = tf.where(indices_random, random_words, inputs)

@@ -903,6 +958,11 @@ def tf_mask_tokens(
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
import tensorflow as tf

if self.seed is not None and self.generator is None:
# If a seed is set, create the generator once; subsequent calls reuse the same generator.
# If no seed is supplied, the global RNG is used.
self.create_rng()

# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], Mapping):
batch = pad_without_fast_tokenizer_warning(
@@ -943,6 +1003,12 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict

def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.

if self.seed is not None and self.generator is None:
# If a seed is set, create the generator once; subsequent calls reuse the same generator.
# If no seed is supplied, the global RNG is used.
self.create_rng()

if isinstance(examples[0], Mapping):
batch = pad_without_fast_tokenizer_warning(
self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
@@ -983,11 +1049,14 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
special_tokens_mask = special_tokens_mask.bool()

probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
indices_replaced = (
torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
& masked_indices
)
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1001,18 +1070,24 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No

# random_replace_prob% of the time, we replace masked input tokens with random word
indices_random = (
torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
& masked_indices
& ~indices_replaced
)
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
inputs[indices_random] = random_words[indices_random]

# The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
return inputs, labels

def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.

if self.seed is not None and self.generator is None:
# If a seed is set, create the generator once; subsequent calls reuse the same generator.
# If no seed is supplied, the global RNG is used.
self.create_rng()

if isinstance(examples[0], Mapping):
batch = pad_without_fast_tokenizer_warning(
self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
@@ -1052,13 +1127,21 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No

probability_matrix[special_tokens_mask] = 0
# Numpy doesn't have bernoulli, so we use a binomial with 1 trial
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
if self.generator:
masked_indices = self.generator.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
else:
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = (
np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
)
if self.generator:
indices_replaced = (
self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
)
else:
indices_replaced = (
np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
)
inputs[indices_replaced] = self.tokenizer.mask_token_id

if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
@@ -1069,14 +1152,24 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
indices_random = (
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
& masked_indices
& ~indices_replaced
)
random_words = np.random.randint(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
)
if self.generator:
indices_random = (
self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
& masked_indices
& ~indices_replaced
)
random_words = self.generator.integers(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
)
else:
indices_random = (
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
& masked_indices
& ~indices_replaced
)
random_words = np.random.randint(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
)
inputs[indices_random] = random_words

# The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
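
For context beyond the diff, a minimal usage sketch of the new `seed` argument (not part of the PR; the checkpoint name and toy sentences are assumptions made for illustration):

# Reproducible masking in a single process: with `seed` set, masking is drawn from a
# dedicated generator instead of the global RNG, so two collators built with the same
# seed mask the same positions in the same batch.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
examples = [tokenizer("a quick brown fox"), tokenizer("jumps over the lazy dog")]

collator_a = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42)
collator_b = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42)

batch_a = collator_a(examples)
batch_b = collator_b(examples)
assert (batch_a["labels"] == batch_b["labels"]).all()  # identical masked positions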
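
A second illustrative sketch (also not part of the PR) covers the multiprocessing path: inside PyTorch DataLoader workers, each copy of the collator seeds its own generator with `seed + worker_id`, so masks differ between workers but remain reproducible across runs. The toy dataset and loader settings below are assumptions.

# Per-worker seeding requires a PyTorch DataLoader, because worker IDs are only
# exposed through torch.utils.data.get_worker_info().
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
texts = ["a quick brown fox", "jumps over the lazy dog", "pack my box", "with five dozen jugs"]
dataset = [tokenizer(t) for t in texts]

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, seed=42)
loader = DataLoader(dataset, batch_size=2, collate_fn=collator, num_workers=2)

# Each worker gets its own copy of the collator; on its first batch it calls create_rng(),
# which seeds a fresh generator with seed + worker_id (here 42 and 43).
for batch in loader:
    print(batch["input_ids"].shape, (batch["labels"] != -100).sum().item())

On spawn-based platforms (Windows, macOS) the loop above would need the usual `if __name__ == "__main__":` guard.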