40 changes: 36 additions & 4 deletions docs/source/model_doc/t5.rst
@@ -16,13 +16,45 @@ To facilitate future work on transfer learning for NLP, we release our dataset,

The authors' code can be found `here <https://github.com/google-research/text-to-text-transfer-transformer>`_ .

Training
~~~~~~~~~~~~~~~~~~~~
T5 is an encoder-decoder model that converts all NLP problems into a text-to-text format. It is trained using teacher forcing.
This means that for training we always need an input sequence and a corresponding target sequence.
The input sequence is fed to the model using ``input_ids``. The target sequence is shifted to the right, *i.e.* prepended by a start-sequence token, and fed to the decoder using ``decoder_input_ids``. In teacher-forcing style, the target sequence is then appended with the EOS token and corresponds to the ``lm_labels``. The PAD token is hereby used as the start-sequence token.
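
The right-shift that produces ``decoder_input_ids`` from ``lm_labels`` can be pictured with the following sketch. It is illustrative only, not the library's internal implementation; it assumes the public ``t5-small`` checkpoint and reads the start-sequence token from ``model.config.pad_token_id``:

::

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

lm_labels = tokenizer.encode('Das Haus ist wunderbar. </s>', return_tensors='pt')
# prepend the PAD (start-sequence) token and drop the last position
pad_token_id = model.config.pad_token_id
decoder_input_ids = torch.cat([torch.full_like(lm_labels[:, :1], pad_token_id), lm_labels[:, :-1]], dim=-1)
# when only lm_labels are passed, the forward pass builds an equivalent shift automatically
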
T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.

- Unsupervised denoising training
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a.* unique mask tokens)
and the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens.
Each sentinel token represents a unique mask token for this sentence and should start with ``<extra_id_1>``, followed by ``<extra_id_2>``, and so on up to ``<extra_id_100>``. By default, 100 sentinel tokens are available in ``T5Tokenizer``.
*E.g.* the sentence "The cute dog walks in the park" with "cute dog" and "the" masked should be processed as follows:

::

input_ids = tokenizer.encode('The <extra_id_1> walks in <extra_id_2> park', return_tensors='pt')
lm_labels = tokenizer.encode('<extra_id_1> cute dog <extra_id_2> the <extra_id_3> </s>', return_tensors='pt')
# the forward function automatically creates the correct decoder_input_ids
model(input_ids=input_ids, lm_labels=lm_labels)

- Supervised training
In this setup the input sequence and the output sequence form a standard sequence-to-sequence input-output mapping.
In translation, *e.g.*, the input sequence "The house is wonderful." and the output sequence "Das Haus ist wunderbar." should
be processed as follows:

::

input_ids = tokenizer.encode('translate English to German: The house is wonderful. </s>', return_tensors='pt')
lm_labels = tokenizer.encode('Das Haus ist wunderbar. </s>', return_tensors='pt')
# the forward function automatically creates the correct decoder_input_ids
model(input_ids=input_ids, lm_labels=lm_labels)
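
A minimal fine-tuning step for the supervised setup could then look as sketched below. This is a sketch, not a prescribed recipe: it assumes ``AdamW`` is available from the same library and that, when ``lm_labels`` are passed, the first element of the returned tuple is the cross-entropy loss (this can differ between library versions):

::

from transformers import AdamW, T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')
optimizer = AdamW(model.parameters(), lr=3e-4)  # learning rate chosen for illustration only

input_ids = tokenizer.encode('translate English to German: The house is wonderful. </s>', return_tensors='pt')
lm_labels = tokenizer.encode('Das Haus ist wunderbar. </s>', return_tensors='pt')

outputs = model(input_ids=input_ids, lm_labels=lm_labels)
loss = outputs[0]  # cross-entropy loss over the target tokens
loss.backward()
optimizer.step()
optimizer.zero_grad()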

Tips
~~~~~~~~~~~~~~~~~~~~
- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised
and supervised tasks, each of which is converted into a text-to-text format.
T5 works well on a variety of tasks out-of-the-box by prepending a different prefix to the input corresponding to each task, *e.g.* for translation: *translate English to German: ...*, for summarization: *summarize: ...*.
For more information about which prefix to use, it is easiest to look into Appendix D of the `paper <https://arxiv.org/pdf/1910.10683.pdf>`_ .
- For sequence to sequence generation, it is recommended to use ``T5ForConditionalGeneration.generate()`` (see the example below). The method takes care of feeding the encoded input via cross-attention layers to the decoder and auto-regressively generates the decoder output.
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
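
As a concrete illustration of the prefix plus ``generate()`` workflow described in the tips above (a sketch assuming the public ``t5-small`` checkpoint, with all generation parameters left at their defaults):

::

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

input_ids = tokenizer.encode('translate English to German: The house is wonderful.', return_tensors='pt')
output_ids = model.generate(input_ids)
print(tokenizer.decode(output_ids[0]))  # expected to produce a German translation similar to: Das Haus ist wunderbar.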


2 changes: 1 addition & 1 deletion src/transformers/modeling_bart.py
@@ -72,7 +72,7 @@
Mask to avoid performing attention on padding token indices in input_ids.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` (`optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
18 changes: 5 additions & 13 deletions src/transformers/modeling_t5.py
@@ -699,32 +699,24 @@ def forward(
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
T5 is a model with relative position embeddings so you should be able to pad the inputs on
the right or the left.

Indices can be obtained using :class:`transformers.T5Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
To learn more about how to prepare :obj:`input_ids` for pre-training, take a look at
`T5 Training <./t5.html#training>`_ .
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` (`optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
To learn more about how to prepare :obj:`decoder_input_ids` for pre-training, take a look at
`T5 Training <./t5.html#training>`_ .
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
20 changes: 6 additions & 14 deletions src/transformers/modeling_tf_t5.py
@@ -637,27 +637,18 @@ def dummy_inputs(self):

input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
T5 is a model with relative position embeddings so you should be able to pad the inputs on
the right or the left.

Indices can be obtained using :class:`transformers.T5Tokenizer`.
To learn more about how to prepare :obj:`input_ids` for pre-training, take a look at
`T5 Training <./t5.html#training>`_ .
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor))`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` (`optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
@@ -671,6 +662,8 @@ def dummy_inputs(self):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
To learn more about how to prepare :obj:`decoder_input_ids` for pre-training, take a look at
`T5 Training <./t5.html#training>`_ .
head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
@@ -836,7 +829,7 @@ def call(self, decoder_input_ids, **kwargs):
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="tf") # Batch size 1
outputs = model(input_ids, input_ids=input_ids, lm_labels=input_ids)
prediction_scores = outputs[:1]

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
@@ -879,7 +872,6 @@ def call(self, decoder_input_ids, **kwargs):
head_mask=head_mask,
)

sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
embed_tokens = self.get_output_embeddings()
lm_logits = embed_tokens(sequence_output, mode="linear")