100 changes: 50 additions & 50 deletions src/transformers/modeling_bart.py
@@ -74,39 +74,37 @@
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
decoder_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, 1, tgt_seq_len, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens and future tokens, as in the paper.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
"""
LARGE_NEGATIVE = -1e8


def invert_mask(attention_mask):
assert attention_mask.dim() == 2
return attention_mask.eq(0)


def _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
"""Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
"""Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
Note: this is not called during generation
"""
pad_token_id = config.pad_token_id
need_causal_mask = not config.output_past
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
bsz, tgt_len = decoder_input_ids.size()[:2]
if decoder_attn_mask is None:
bsz, tgt_len = decoder_input_ids.size()
if decoder_padding_mask is None:
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
if need_causal_mask:
causal_lm_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
else:
causal_lm_mask = None
new_shape = (bsz, tgt_len, tgt_len)
# make it broadcastable so can just be added to the attention coefficients
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
if mask_dtype is not None:
decoder_attn_mask = decoder_attn_mask.to(mask_dtype)
assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
return decoder_input_ids, decoder_attn_mask
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
return decoder_input_ids, decoder_padding_mask, causal_mask
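
For orientation, here is a minimal sketch (not part of this diff) of the two masks the refactored helper now returns, built by hand for a toy batch. It assumes `pad_token_id == 1`; the real helper delegates to `make_padding_mask` and `fill_with_neg_inf` from this module, but the resulting tensors look like this:

```python
import torch

pad_token_id = 1                                            # assumed for this toy example
decoder_input_ids = torch.tensor([[0, 8, 12, 2, pad_token_id]])
bsz, tgt_len = decoder_input_ids.size()

# Padding mask: True where the decoder input is a pad token, so those keys are ignored.
decoder_padding_mask = decoder_input_ids.eq(pad_token_id)   # shape (bsz, tgt_len)

# Causal mask: -inf strictly above the diagonal, so position i cannot attend to j > i.
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)

print(decoder_padding_mask)   # tensor([[False, False, False, False,  True]])
print(causal_mask[0])         # tensor([0., -inf, -inf, -inf, -inf])
```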


class PretrainedBartModel(PreTrainedModel):
@@ -130,12 +128,9 @@ def _init_weights(self, module):
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids)
dummy_inputs = {
"decoder_input_ids": decoder_input_ids,
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
"decoder_attention_mask": decoder_attn_mask,
}
return dummy_inputs

@@ -153,21 +148,6 @@ def _check_shapes(shape_1, shape2):
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))


def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
"""Make one mask of shape (bsz, 1, tgt_len, src_len) """
a = torch.zeros(targ_size) # targ_size is(bsz, tgt_len, src_len)
b = torch.zeros(targ_size)
if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size
_check_shapes(key_padding_mask.shape, targ_size[:2])
reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
a[reshaped] = LARGE_NEGATIVE

if causal_lm_mask is not None: # (tgt_len, src_len) -> targ_size
_check_shapes(causal_lm_mask.shape, targ_size[-2:])
b = causal_lm_mask.unsqueeze(0).expand(*targ_size)
return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)


def shift_tokens_right(input_ids, pad_token_id):
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
prev_output_tokens = input_ids.clone()
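
A toy walk-through of the shift with hypothetical ids (pad = 1, `<eos>` = 2), to make the wrap-around explicit:

```python
import torch
from transformers.modeling_bart import shift_tokens_right

input_ids = torch.tensor([[0, 8, 12, 2, 1]])   # ends with <eos> = 2, then one pad = 1
shifted = shift_tokens_right(input_ids, pad_token_id=1)
# The last non-pad token (<eos>) is wrapped around to position 0 and everything
# else moves one slot to the right:
# shifted == tensor([[2, 0, 8, 12, 2]])
```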
@@ -281,8 +261,7 @@ def forward(
"""
# check attention mask and invert
if attention_mask is not None:
assert attention_mask.dim() == 2
attention_mask = attention_mask.eq(0)
attention_mask = invert_mask(attention_mask)

inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids)
@@ -339,15 +318,26 @@ def __init__(self, config: BartConfig):
self.final_layer_norm = LayerNorm(self.embed_dim)

def forward(
self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None,
self,
x,
encoder_hidden_states,
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
decoder_padding_mask=None,
):
residual = x

if layer_state is None:
layer_state = {}
# next line mutates layer state
x, self_attn_weights = self.self_attn(
query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions
query=x,
key=x,
layer_state=layer_state,
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
need_weights=self.output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -412,7 +402,8 @@ def forward(
input_ids,
encoder_hidden_states,
encoder_padding_mask,
combined_mask,
decoder_padding_mask,
decoder_causal_mask,
decoder_cached_states=None,
generation_mode=False,
**unused
@@ -437,8 +428,7 @@
"""
# check attention mask and invert
if encoder_padding_mask is not None:
assert encoder_padding_mask.dim() == 2
encoder_padding_mask = encoder_padding_mask.eq(0)
encoder_padding_mask = invert_mask(encoder_padding_mask)

# embed positions
positions = self.embed_positions(input_ids, generation_mode=generation_mode)
@@ -458,7 +448,6 @@ def forward(
all_hidden_states = ()
all_self_attns = ()
next_decoder_cache = []

for i, decoder_layer in enumerate(self.layers):
decoder_layer # type: DecoderLayer
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -468,7 +457,12 @@

layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer(
x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
x,
encoder_hidden_states,
encoder_attn_mask=encoder_padding_mask,
decoder_padding_mask=decoder_padding_mask,
layer_state=layer_state,
causal_mask=decoder_causal_mask,
)

if self.output_past:
@@ -736,6 +730,8 @@ def _filter_out_falsey_values(tup) -> Tuple:


# Public API
def _get_shape(t):
return getattr(t, "shape", None)


@add_start_docstrings(
@@ -769,13 +765,16 @@ def forward(

# make masks if user doesn't supply
if not generation_mode:
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
decoder_input_ids=decoder_input_ids,
decoder_attn_mask=decoder_attention_mask,
mask_dtype=self.shared.weight.dtype,
decoder_padding_mask=decoder_attention_mask,
causal_mask_dtype=self.shared.weight.dtype,
)
else:
decoder_padding_mask, causal_mask = None, None

assert decoder_input_ids is not None
if encoder_outputs is None:
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -785,7 +784,8 @@ def forward(
decoder_input_ids,
encoder_outputs[0],
attention_mask,
decoder_attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
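
Taken together, the user-facing contract is now: a supplied `decoder_attention_mask` is a plain 2D padding mask over `decoder_input_ids`, and the causal mask is always built internally outside of generation mode. A hedged usage sketch with a toy, randomly initialized config (the sizes are hypothetical, chosen only to make the forward pass cheap):

```python
import torch
from transformers import BartConfig, BartModel

config = BartConfig(
    vocab_size=99, d_model=16, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32, max_position_embeddings=32,
)
model = BartModel(config).eval()

# Same toy ids as dummy_inputs above; the second sequence ends with a pad token.
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, config.pad_token_id]])
attention_mask = input_ids.ne(config.pad_token_id)

with torch.no_grad():
    # No decoder ids or masks supplied: the model shifts input_ids right and builds
    # decoder_padding_mask + causal_mask via _prepare_bart_decoder_inputs.
    decoder_hidden_states = model(input_ids, attention_mask=attention_mask)[0]

print(decoder_hidden_states.shape)   # torch.Size([2, 5, 16])
```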
53 changes: 21 additions & 32 deletions tests/test_modeling_bart.py
@@ -36,8 +36,8 @@
from transformers.modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
shift_tokens_right,
invert_mask,
_prepare_bart_decoder_inputs,
LARGE_NEGATIVE,
)
from transformers.tokenization_bart import BartTokenizer

@@ -123,10 +123,9 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()

def test_advanced_inputs(self):
def test_initialization_more(self):
# (config, input_ids, token_type_ids, input_mask, *unused) = \
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, inputs_dict["input_ids"])
model = BartModel(config)
model.to(torch_device)
model.eval()
@@ -142,9 +141,17 @@ def _check_var(module):
_check_var(model.encoder.layers[0].fc1)
_check_var(model.encoder.embed_positions)

def test_advanced_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
config, inputs_dict["input_ids"]
)
model = BartModel(config).to(torch_device).eval()

decoder_features_with_created_mask = model(**inputs_dict)[0]
decoder_features_with_passed_mask = model(
decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
)[0]
_assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
useless_mask = torch.zeros_like(decoder_attn_mask)
@@ -238,7 +245,7 @@ def test_lm_forward(self):
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_model = BartForConditionalGeneration(config)
lm_model.to(torch_device)
loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids)
loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
self.assertIsInstance(loss.item(), float)
@@ -336,41 +343,23 @@ def test_default_generate_kwargs(self):
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

def test_dummy_inputs(self):
config, *_ = self._get_config_and_data(output_past=True)
config, *_ = self._get_config_and_data()
model = BartForConditionalGeneration(config).eval().to(torch_device)
model(**model.dummy_inputs)

def test_prepare_bart_decoder_inputs(self):
config, *_ = self._get_config_and_data(output_past=False)
input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed
input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = LARGE_NEGATIVE
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
expected_mask = torch.tensor(
[
[0, ignore, ignore],
[0, 0, ignore],
[ignore, ignore, ignore], # never attend to the final token, because its pad
]
).to(input_ids.device)
self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())

# Test no causal mask
config, *_ = self._get_config_and_data(output_past=True)
expected_just_padding_mask = torch.tensor(
[[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad
).to(input_ids.device)
_, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())

decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
# Attend to everything if no pad tokens and no causal mask
_, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
ignore = float("-inf")
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids
)
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
expected_causal_mask = torch.tensor(
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
).to(input_ids.device)
self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())

def test_resize_tokens_embeddings_more(self):
config, input_ids, _ = self._get_config_and_data()