diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0415f942cf14..23e9a6fdcb99 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -114,6 +114,8 @@ title: Logging - local: main_classes/model title: Models + - local: main_classes/text_generation + title: Text Generation - local: main_classes/onnx title: ONNX - local: main_classes/optimizer_schedules diff --git a/docs/source/main_classes/model.mdx b/docs/source/main_classes/model.mdx index d65ae8516e32..4da5e72b7ed1 100644 --- a/docs/source/main_classes/model.mdx +++ b/docs/source/main_classes/model.mdx @@ -86,14 +86,6 @@ Due to Pytorch design, this functionality is only available for floating dtypes. - push_to_hub - all -## Generation - -[[autodoc]] generation_utils.GenerationMixin - -[[autodoc]] generation_tf_utils.TFGenerationMixin - -[[autodoc]] generation_flax_utils.FlaxGenerationMixin - ## Pushing to the Hub [[autodoc]] file_utils.PushToHubMixin diff --git a/docs/source/main_classes/text_generation.mdx b/docs/source/main_classes/text_generation.mdx new file mode 100644 index 000000000000..509dfe750ad8 --- /dev/null +++ b/docs/source/main_classes/text_generation.mdx @@ -0,0 +1,39 @@ + + +# Generation + +The methods for auto-regressive text generation, namely [`~generation_utils.GenerationMixin.generate`] (for the PyTorch models), [`~generation_tf_utils.TFGenerationMixin.generate`] (for the TensorFlow models) and [`~generation_flax_utils.FlaxGenerationMixin.generate`] (for the Flax/JAX models), are implemented in [`~generation_utils.GenerationMixin`], [`~generation_tf_utils.TFGenerationMixin`] and [`~generation_flax_utils.FlaxGenerationMixin`] respectively. + +The `GenerationMixin` classes are inherited by the corresponding base model classes, *e.g.* [`PreTrainedModel`], [`TFPreTrainedModel`], and [`FlaxPreTrainedModel`] respectively, therefore exposing all +methods for auto-regressive text generation to every model class. + +## GenerationMixn + +[[autodoc]] generation_utils.GenerationMixin + - generate + - greedy_search + - sample + - beam_search + - beam_sample + - group_beam_search + - constrained_beam_search + +## TFGenerationMixn + +[[autodoc]] generation_tf_utils.TFGenerationMixin + - generate + +## FlaxGenerationMixn + +[[autodoc]] generation_flax_utils.FlaxGenerationMixin + - generate diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index a9f76d738e96..2bc6db2f56dd 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -118,7 +118,16 @@ class BeamSearchState: class FlaxGenerationMixin: """ - A class containing all of the functions supporting generation, to be used as a mixin in [`FlaxPreTrainedModel`]. + A class containing all functions for auto-regressive text generation, to be used as a mixin in + [`FlaxPreTrainedModel`]. + + The class exposes [`~generation_flax_utils.FlaxGenerationMixin.generate`], which can be used for: + - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if + `num_beams=1` and `do_sample=False`. + - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1` + and `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1` + and `do_sample=False`. """ @staticmethod @@ -176,12 +185,23 @@ def generate( **model_kwargs, ): r""" - Generates sequences for models with a language modeling head. The method currently supports greedy decoding, - and, multinomial sampling. + Generates sequences of token ids for models with a language modeling head. The method supports the following + generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name - inside the [`PretrainedConfig`] of the model. The default values indicated are the default values of those - config. + - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if + `num_beams=1` and `do_sample=False`. + - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1` + and `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1` + and `do_sample=False`. + + + + Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as + defined in the model's config (`config.json`) which in turn defaults to the + [`~modeling_utils.PretrainedConfig`] of the model. + + Most of these parameters are explained in more detail in [this blog post](https://huggingface.co/blog/how-to-generate). @@ -236,7 +256,7 @@ def generate( >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids >>> # generate candidates using sampling >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ```""" # set init values max_length = max_length if max_length is not None else self.config.max_length diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index d9a901d201d9..85bbc51e6f23 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -377,7 +377,21 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): class GenerationMixin: """ - A class containing all of the functions supporting generation, to be used as a mixin in [`PreTrainedModel`]. + A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + + The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], + if `constraints!=None` or `force_words_ids!=None`. """ def _prepare_model_inputs( @@ -847,18 +861,37 @@ def generate( **model_kwargs, ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: r""" - Generates sequences for models with a language modeling head. The method currently supports greedy decoding, - multinomial sampling, beam-search decoding, and beam-search multinomial sampling. - Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name inside - the [`PretrainedConfig`] of the model. The default values indicated are the default values of those config. + Generates sequences of token ids for models with a language modeling head. The method supports the following + generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling + [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or + `force_words_ids!=None`. + + + + Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as + defined in the model's config (`config.json`) which in turn defaults to the + [`~modeling_utils.PretrainedConfig`] of the model. + + Most of these parameters are explained in more detail in [this blog post](https://huggingface.co/blog/how-to-generate). Parameters: - inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length, - feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*): + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of @@ -997,66 +1030,56 @@ def generate( Examples: + Greedy Decoding: + ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM + >>> from transformers import AutoTokenizer, AutoModelForCausalLM - >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") - >>> # do greedy decoding without providing a prompt - >>> outputs = model.generate(max_length=40) - >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - >>> document = ( - ... "at least two people were killed in a suspected bomb attack on a passenger bus " - ... "in the strife-torn southern philippines on monday , the military said." - ... ) - >>> # encode input context - >>> input_ids = tokenizer(document, return_tensors="pt").input_ids - >>> # generate 3 independent sequences using beam search decoding (5 beams) - >>> # with T5 encoder-decoder model conditioned on short news article. - >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) - - >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") - >>> input_context = "The dog" - >>> # encode input context - >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids - >>> # generate 3 candidates using sampling - >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) - - >>> tokenizer = AutoTokenizer.from_pretrained("ctrl") - >>> model = AutoModelForCausalLM.from_pretrained("ctrl") - >>> # "Legal" is one of the control codes for ctrl - >>> input_context = "Legal My neighbor is" - >>> # encode input context - >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids - >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2) - >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False) + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # generate up to 30 tokens + >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] + ``` + + Multinomial Sampling: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - >>> input_context = "My cute dog" - >>> # get tokens of words that should not be generated - >>> bad_words_ids = tokenizer( - ... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False - >>> ).input_ids - >>> # get tokens of words that we want generated - >>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids - >>> # encode input context - >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids - >>> # generate sequences without allowing bad_words to be generated - >>> outputs = model.generate( - ... input_ids=input_ids, - ... max_length=20, - ... do_sample=True, - ... bad_words_ids=bad_words_ids, - ... force_words_ids=force_words_ids, - ... ) - >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # sample up to 30 tokens + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] + ``` + + Beam-search decoding: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + + >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + + >>> sentence = "Paris is one of the densest populated areas in Europe." + >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids + + >>> outputs = model.generate(input_ids) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" # 1. Set generation parameters if not already defined bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id @@ -1457,7 +1480,8 @@ def greedy_search( **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" - Generates sequences for models with a language modeling head using greedy decoding. + Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be + used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: @@ -1508,6 +1532,8 @@ def greedy_search( ... AutoModelForCausalLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, ... ) >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") @@ -1516,26 +1542,30 @@ def greedy_search( >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token >>> model.config.pad_token_id = model.config.eos_token_id - >>> input_prompt = "Today is a beautiful day, and" + >>> input_prompt = "It might be possible to" >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ - ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), ... ] ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor) + >>> outputs = model.greedy_search( + ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria + ... ) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) @@ -1683,7 +1713,8 @@ def sample( **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: r""" - Generates sequences for models with a language modeling head using multinomial sampling. + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: @@ -1739,7 +1770,10 @@ def sample( ... MinLengthLogitsProcessor, ... TopKLogitsWarper, ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, ... ) + >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") @@ -1764,9 +1798,18 @@ def sample( ... ] ... ) - >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ```""" # init values @@ -1926,7 +1969,8 @@ def beam_search( **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" - Generates sequences for models with a language modeling head using beam search decoding. + Generates sequences of token ids for models with a language modeling head using **beam search decoding** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: @@ -2020,7 +2064,8 @@ def beam_search( >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -2237,7 +2282,8 @@ def beam_sample( **model_kwargs, ) -> Union[BeamSampleOutput, torch.LongTensor]: r""" - Generates sequences for models with a language modeling head using beam search with multinomial sampling. + Generates sequences of token ids for models with a language modeling head using **beam search multinomial + sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: @@ -2343,7 +2389,8 @@ def beam_sample( ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs ... ) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -2556,7 +2603,8 @@ def group_beam_search( **model_kwargs, ): r""" - Generates sequences for models with a language modeling head using beam search decoding. + Generates sequences of token ids for models with a language modeling head using **diverse beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: @@ -2656,7 +2704,8 @@ def group_beam_search( ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs ... ) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -2920,7 +2969,8 @@ def constrained_beam_search( ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" - Generates sequences for models with a language modeling head using beam search decoding. + Generates sequences of token ids for models with a language modeling head using **constrained beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -3024,8 +3074,8 @@ def constrained_beam_search( ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs ... ) - >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) - # => ['Wie alter sind Sie?'] + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt sind Sie?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 7c15c26f0705..1bbba630c201 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -28,5 +28,6 @@ src/transformers/models/pegasus/modeling_pegasus.py src/transformers/models/blenderbot/modeling_blenderbot.py src/transformers/models/blenderbot_small/modeling_blenderbot_small.py src/transformers/models/plbart/modeling_plbart.py +src/transformers/generation_utils.py docs/source/quicktour.mdx -docs/source/task_summary.mdx \ No newline at end of file +docs/source/task_summary.mdx