diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 0415f942cf14..23e9a6fdcb99 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -114,6 +114,8 @@
title: Logging
- local: main_classes/model
title: Models
+ - local: main_classes/text_generation
+ title: Text Generation
- local: main_classes/onnx
title: ONNX
- local: main_classes/optimizer_schedules
diff --git a/docs/source/main_classes/model.mdx b/docs/source/main_classes/model.mdx
index d65ae8516e32..4da5e72b7ed1 100644
--- a/docs/source/main_classes/model.mdx
+++ b/docs/source/main_classes/model.mdx
@@ -86,14 +86,6 @@ Due to Pytorch design, this functionality is only available for floating dtypes.
- push_to_hub
- all
-## Generation
-
-[[autodoc]] generation_utils.GenerationMixin
-
-[[autodoc]] generation_tf_utils.TFGenerationMixin
-
-[[autodoc]] generation_flax_utils.FlaxGenerationMixin
-
## Pushing to the Hub
[[autodoc]] file_utils.PushToHubMixin
diff --git a/docs/source/main_classes/text_generation.mdx b/docs/source/main_classes/text_generation.mdx
new file mode 100644
index 000000000000..509dfe750ad8
--- /dev/null
+++ b/docs/source/main_classes/text_generation.mdx
@@ -0,0 +1,39 @@
+# Generation
+
+The methods for auto-regressive text generation, namely [`~generation_utils.GenerationMixin.generate`] (for the PyTorch models), [`~generation_tf_utils.TFGenerationMixin.generate`] (for the TensorFlow models) and [`~generation_flax_utils.FlaxGenerationMixin.generate`] (for the Flax/JAX models), are implemented in [`~generation_utils.GenerationMixin`], [`~generation_tf_utils.TFGenerationMixin`] and [`~generation_flax_utils.FlaxGenerationMixin`] respectively.
+
+Each of these mixin classes is inherited by the corresponding base model class ([`PreTrainedModel`],
+[`TFPreTrainedModel`], and [`FlaxPreTrainedModel`], respectively), thereby exposing all
+methods for auto-regressive text generation to every model class.
+
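+For example, the snippet below is a minimal sketch of greedy decoding with a PyTorch model (the `gpt2` checkpoint
+is only an illustration; any causal language model can be used the same way):
+
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+# encode a prompt and greedily generate up to 30 tokens
+input_ids = tokenizer("Today I believe we can finally", return_tensors="pt").input_ids
+outputs = model.generate(input_ids, do_sample=False, max_length=30)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+```
+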
+## GenerationMixin
+
+[[autodoc]] generation_utils.GenerationMixin
+ - generate
+ - greedy_search
+ - sample
+ - beam_search
+ - beam_sample
+ - group_beam_search
+ - constrained_beam_search
+
+## TFGenerationMixin
+
+[[autodoc]] generation_tf_utils.TFGenerationMixin
+ - generate
+
+## FlaxGenerationMixin
+
+[[autodoc]] generation_flax_utils.FlaxGenerationMixin
+ - generate
diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py
index a9f76d738e96..2bc6db2f56dd 100644
--- a/src/transformers/generation_flax_utils.py
+++ b/src/transformers/generation_flax_utils.py
@@ -118,7 +118,16 @@ class BeamSearchState:
class FlaxGenerationMixin:
"""
- A class containing all of the functions supporting generation, to be used as a mixin in [`FlaxPreTrainedModel`].
+ A class containing all functions for auto-regressive text generation, to be used as a mixin in
+ [`FlaxPreTrainedModel`].
+
+ The class exposes [`~generation_flax_utils.FlaxGenerationMixin.generate`], which can be used for:
+ - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
+ `num_beams=1` and `do_sample=False`.
+ - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
+ and `do_sample=True`.
+ - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
+ and `do_sample=False`.
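+
+ As a minimal sketch (the `gpt2` checkpoint is only an illustration), greedy decoding can be run as:
+
+ ```python
+ from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
+
+ # num_beams=1 (the default) together with do_sample=False dispatches to _greedy_search
+ input_ids = tokenizer("Hello, my dog is", return_tensors="np").input_ids
+ outputs = model.generate(input_ids, max_length=20, do_sample=False)
+ print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
+ ```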
"""
@staticmethod
@@ -176,12 +185,23 @@ def generate(
**model_kwargs,
):
r"""
- Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
- and, multinomial sampling.
+ Generates sequences of token ids for models with a language modeling head. The method supports the following
+ generation strategies for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name
- inside the [`PretrainedConfig`] of the model. The default values indicated are the default values of those
- config.
+ - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
+ `num_beams=1` and `do_sample=False`.
+ - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
+ and `do_sample=True`.
+ - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
+ and `do_sample=False`.
+
+ Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name as
+ defined in the model's config (`config.json`), which in turn defaults to the
+ [`PretrainedConfig`] of the model.
+
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
@@ -236,7 +256,7 @@ def generate(
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
>>> # generate candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index d9a901d201d9..85bbc51e6f23 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -377,7 +377,21 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
class GenerationMixin:
"""
- A class containing all of the functions supporting generation, to be used as a mixin in [`PreTrainedModel`].
+ A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
+
+ The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
+ - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
+ `do_sample=False`.
+ - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
+ `do_sample=True`.
+ - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
+ `do_sample=False`.
+ - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
+ `num_beams>1` and `do_sample=True`.
+ - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
+ `num_beams>1` and `num_beam_groups>1`.
+ - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
+ if `constraints!=None` or `force_words_ids!=None`.
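+
+ As a minimal sketch of how these flags select a decoding method (the `t5-small` checkpoint is only an
+ illustration):
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ tokenizer = AutoTokenizer.from_pretrained("t5-small")
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
+
+ input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids
+ # num_beams>1 with do_sample=False dispatches to beam_search under the hood
+ outputs = model.generate(input_ids, num_beams=4, do_sample=False)
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ ```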
"""
def _prepare_model_inputs(
@@ -847,18 +861,37 @@ def generate(
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
r"""
- Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
- multinomial sampling, beam-search decoding, and beam-search multinomial sampling.
- Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name inside
- the [`PretrainedConfig`] of the model. The default values indicated are the default values of those config.
+ Generates sequences of token ids for models with a language modeling head. The method supports the following
+ generation strategies for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
+
+ - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
+ `do_sample=False`.
+ - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
+ `do_sample=True`.
+ - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
+ `do_sample=False`.
+ - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
+ `num_beams>1` and `do_sample=True`.
+ - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
+ `num_beams>1` and `num_beam_groups>1`.
+ - *constrained beam-search decoding* by calling
+ [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
+ `force_words_ids!=None`.
+
+ Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
+ defined in the model's config (`config.json`), which in turn defaults to the
+ [`PretrainedConfig`] of the model.
+
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
Parameters:
- inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
- feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
@@ -997,66 +1030,56 @@ def generate(
Examples:
+ Greedy Decoding:
+
```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
- >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
- >>> # do greedy decoding without providing a prompt
- >>> outputs = model.generate(max_length=40)
- >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
- >>> document = (
- ... "at least two people were killed in a suspected bomb attack on a passenger bus "
- ... "in the strife-torn southern philippines on monday , the military said."
- ... )
- >>> # encode input context
- >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
- >>> # generate 3 independent sequences using beam search decoding (5 beams)
- >>> # with T5 encoder-decoder model conditioned on short news article.
- >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
-
- >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
- >>> input_context = "The dog"
- >>> # encode input context
- >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
- >>> # generate 3 candidates using sampling
- >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
-
- >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
- >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
- >>> # "Legal" is one of the control codes for ctrl
- >>> input_context = "Legal My neighbor is"
- >>> # encode input context
- >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
- >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
- >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
-
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
+ >>> prompt = "Today I believe we can finally"
+ >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ >>> # generate up to 30 tokens
+ >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
+ ```
+
+ Multinomial Sampling:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
- >>> input_context = "My cute dog"
- >>> # get tokens of words that should not be generated
- >>> bad_words_ids = tokenizer(
- ... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
- >>> ).input_ids
- >>> # get tokens of words that we want generated
- >>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids
- >>> # encode input context
- >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
- >>> # generate sequences without allowing bad_words to be generated
- >>> outputs = model.generate(
- ... input_ids=input_ids,
- ... max_length=20,
- ... do_sample=True,
- ... bad_words_ids=bad_words_ids,
- ... force_words_ids=force_words_ids,
- ... )
- >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+ >>> prompt = "Today I believe we can finally"
+ >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ >>> # sample up to 30 tokens
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
+ >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
+ ```
+
+ Beam-search Decoding:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+
+ >>> sentence = "Paris is one of the densest populated areas in Europe."
+ >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
+
+ >>> outputs = model.generate(input_ids)
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
@@ -1457,7 +1480,8 @@ def greedy_search(
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
r"""
- Generates sequences for models with a language modeling head using greedy decoding.
+ Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
+ used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -1508,6 +1532,8 @@ def greedy_search(
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
+ ... StoppingCriteriaList,
+ ... MaxLengthCriteria,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -1516,26 +1542,30 @@ def greedy_search(
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
- >>> input_prompt = "Today is a beautiful day, and"
+ >>> input_prompt = "It might be possible to"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [
- ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
+ ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
... ]
... )
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
- >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
+ >>> outputs = model.greedy_search(
+ ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
+ ... )
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
- "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+ "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
@@ -1683,7 +1713,8 @@ def sample(
**model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
r"""
- Generates sequences for models with a language modeling head using multinomial sampling.
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -1739,7 +1770,10 @@ def sample(
... MinLengthLogitsProcessor,
... TopKLogitsWarper,
... TemperatureLogitsWarper,
+ ... StoppingCriteriaList,
+ ... MaxLengthCriteria,
... )
+ >>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
@@ -1764,9 +1798,18 @@ def sample(
... ]
... )
- >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
+ >>> outputs = model.sample(
+ ... input_ids,
+ ... logits_processor=logits_processor,
+ ... logits_warper=logits_warper,
+ ... stopping_criteria=stopping_criteria,
+ ... )
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
```"""
# init values
@@ -1926,7 +1969,8 @@ def beam_search(
**model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
- Generates sequences for models with a language modeling head using beam search decoding.
+ Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -2020,7 +2064,8 @@ def beam_search(
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Wie alt bist du?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2237,7 +2282,8 @@ def beam_sample(
**model_kwargs,
) -> Union[BeamSampleOutput, torch.LongTensor]:
r"""
- Generates sequences for models with a language modeling head using beam search with multinomial sampling.
+ Generates sequences of token ids for models with a language modeling head using **beam search multinomial
+ sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -2343,7 +2389,8 @@ def beam_sample(
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
... )
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Wie alt bist du?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2556,7 +2603,8 @@ def group_beam_search(
**model_kwargs,
):
r"""
- Generates sequences for models with a language modeling head using beam search decoding.
+ Generates sequences of token ids for models with a language modeling head using **diverse beam search
+ decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -2656,7 +2704,8 @@ def group_beam_search(
... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
... )
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Wie alt bist du?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2920,7 +2969,8 @@ def constrained_beam_search(
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
- Generates sequences for models with a language modeling head using beam search decoding.
+ Generates sequences of token ids for models with a language modeling head using **constrained beam search
+ decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -3024,8 +3074,8 @@ def constrained_beam_search(
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
... )
- >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
- # => ['Wie alter sind Sie?']
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Wie alt sind Sie?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 7c15c26f0705..1bbba630c201 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -28,5 +28,6 @@ src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/plbart/modeling_plbart.py
+src/transformers/generation_utils.py
docs/source/quicktour.mdx
-docs/source/task_summary.mdx
\ No newline at end of file
+docs/source/task_summary.mdx