From 3f9723d880a23e737754ed787600891936ef630c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Aug 2023 10:07:06 +0000 Subject: [PATCH 1/7] Add missing logits processors docs --- docs/source/en/internal/generation_utils.md | 60 +++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index f5c882bf128a..e159923a171f 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -129,6 +129,9 @@ generation. [[autodoc]] RepetitionPenaltyLogitsProcessor - __call__ +[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor + - __call__ + [[autodoc]] TopPLogitsWarper - __call__ @@ -138,9 +141,18 @@ generation. [[autodoc]] TypicalLogitsWarper - __call__ +[[autodoc]] EpsilonLogitsWarper + - __call__ + +[[autodoc]] EtaLogitsWarper + - __call__ + [[autodoc]] NoRepeatNGramLogitsProcessor - __call__ +[[autodoc]] EncoderNoRepeatNGramLogitsProcessor + - __call__ + [[autodoc]] SequenceBiasLogitsProcessor - __call__ @@ -162,6 +174,33 @@ generation. [[autodoc]] InfNanRemoveLogitsProcessor - __call__ +[[autodoc]] ExponentialDecayLengthPenalty + - __call__ + +[[autodoc]] LogitNormalization + - __call__ + +[[autodoc]] SuppressTokensAtBeginLogitsProcessor + - __call__ + +[[autodoc]] SuppressTokensLogitsProcessor + - __call__ + +[[autodoc]] ForceTokensLogitsProcessor + - __call__ + +[[autodoc]] WhisperTimeStampLogitsProcessor + - __call__ + +[[autodoc]] ClassifierFreeGuidanceLogitsProcessor + - __call__ + +[[autodoc]] AlternatingCodebooksLogitsProcessor + - __call__ + +[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor + - __call__ + [[autodoc]] TFLogitsProcessor - __call__ @@ -198,6 +237,15 @@ generation. 
[[autodoc]] TFForcedEOSTokenLogitsProcessor - __call__ +[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor + - __call__ + +[[autodoc]] TFSuppressTokensLogitsProcessor + - __call__ + +[[autodoc]] TFForceTokensLogitsProcessor + - __call__ + [[autodoc]] FlaxLogitsProcessor - __call__ @@ -225,6 +273,18 @@ generation. [[autodoc]] FlaxMinLengthLogitsProcessor - __call__ +[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor + - __call__ + +[[autodoc]] FlaxSuppressTokensLogitsProcessor + - __call__ + +[[autodoc]] FlaxForceTokensLogitsProcessor + - __call__ + +[[autodoc]] FlaxWhisperTimeStampLogitsProcessor + - __call__ + ## StoppingCriteria A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). From 5b371ede42c5116d27f1fcbdf3c4a8c2cdcd2b72 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Aug 2023 10:59:18 +0000 Subject: [PATCH 2/7] add missing imports --- src/transformers/__init__.py | 28 +++++++++++ src/transformers/generation/__init__.py | 22 ++++++--- src/transformers/utils/dummy_flax_objects.py | 28 +++++++++++ src/transformers/utils/dummy_pt_objects.py | 49 ++++++++++++++++++++ src/transformers/utils/dummy_tf_objects.py | 21 +++++++++ 5 files changed, 141 insertions(+), 7 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9b95aadffccc..f19bc99b4f0b 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1011,11 +1011,17 @@ "Constraint", "ConstraintListState", "DisjunctiveConstraint", + "EncoderNoRepeatNGramLogitsProcessor", + "EncoderRepetitionPenaltyLogitsProcessor", + "EpsilonLogitsWarper", + "EtaLogitsWarper", + "ExponentialDecayLengthPenalty", "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", "GenerationMixin", "HammingDiversityLogitsProcessor", "InfNanRemoveLogitsProcessor", + "LogitNormalization", "LogitsProcessor", "LogitsProcessorList", "LogitsWarper", @@ -1035,6 +1041,7 @@ "TopKLogitsWarper", "TopPLogitsWarper", "TypicalLogitsWarper", 
+ "UnbatchedClassifierFreeGuidanceLogitsProcessor", "top_k_top_p_filtering", ] ) @@ -3115,6 +3122,7 @@ [ "TFForcedBOSTokenLogitsProcessor", "TFForcedEOSTokenLogitsProcessor", + "TFForceTokensLogitsProcessor", "TFGenerationMixin", "TFLogitsProcessor", "TFLogitsProcessorList", @@ -3123,6 +3131,8 @@ "TFNoBadWordsLogitsProcessor", "TFNoRepeatNGramLogitsProcessor", "TFRepetitionPenaltyLogitsProcessor", + "TFSuppressTokensAtBeginLogitsProcessor", + "TFSuppressTokensLogitsProcessor", "TFTemperatureLogitsWarper", "TFTopKLogitsWarper", "TFTopPLogitsWarper", @@ -3836,14 +3846,18 @@ [ "FlaxForcedBOSTokenLogitsProcessor", "FlaxForcedEOSTokenLogitsProcessor", + "FlaxForceTokensLogitsProcessor", "FlaxGenerationMixin", "FlaxLogitsProcessor", "FlaxLogitsProcessorList", "FlaxLogitsWarper", "FlaxMinLengthLogitsProcessor", "FlaxTemperatureLogitsWarper", + "FlaxSuppressTokensAtBeginLogitsProcessor", + "FlaxSuppressTokensLogitsProcessor", "FlaxTopKLogitsWarper", "FlaxTopPLogitsWarper", + "FlaxWhisperTimeStampLogitsProcessor", ] ) _import_structure["generation_flax_utils"] = [] @@ -4988,11 +5002,17 @@ Constraint, ConstraintListState, DisjunctiveConstraint, + EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, + ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, GenerationMixin, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, + LogitNormalization, LogitsProcessor, LogitsProcessorList, LogitsWarper, @@ -5012,6 +5032,7 @@ TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, + UnbatchedClassifierFreeGuidanceLogitsProcessor, top_k_top_p_filtering, ) from .modeling_utils import PreTrainedModel @@ -6711,6 +6732,7 @@ from .generation import ( TFForcedBOSTokenLogitsProcessor, TFForcedEOSTokenLogitsProcessor, + TFForceTokensLogitsProcessor, TFGenerationMixin, TFLogitsProcessor, TFLogitsProcessorList, @@ -6719,6 +6741,8 @@ TFNoBadWordsLogitsProcessor, 
TFNoRepeatNGramLogitsProcessor, TFRepetitionPenaltyLogitsProcessor, + TFSuppressTokensAtBeginLogitsProcessor, + TFSuppressTokensLogitsProcessor, TFTemperatureLogitsWarper, TFTopKLogitsWarper, TFTopPLogitsWarper, @@ -7284,14 +7308,18 @@ from .generation import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, + FlaxForceTokensLogitsProcessor, FlaxGenerationMixin, FlaxLogitsProcessor, FlaxLogitsProcessorList, FlaxLogitsWarper, FlaxMinLengthLogitsProcessor, + FlaxSuppressTokensAtBeginLogitsProcessor, + FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, + FlaxWhisperTimeStampLogitsProcessor, ) from .modeling_flax_utils import FlaxPreTrainedModel diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index f0da9f514e7a..2bfdd529f932 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -41,12 +41,16 @@ "ConstrainedBeamSearchScorer", ] _import_structure["logits_process"] = [ + "EncoderNoRepeatNGramLogitsProcessor", + "EncoderRepetitionPenaltyLogitsProcessor", "EpsilonLogitsWarper", "EtaLogitsWarper", + "ExponentialDecayLengthPenalty", "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", "HammingDiversityLogitsProcessor", "InfNanRemoveLogitsProcessor", + "LogitNormalization", "LogitsProcessor", "LogitsProcessorList", "LogitsWarper", @@ -57,14 +61,10 @@ "PrefixConstrainedLogitsProcessor", "RepetitionPenaltyLogitsProcessor", "SequenceBiasLogitsProcessor", - "EncoderRepetitionPenaltyLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", "TypicalLogitsWarper", - "EncoderNoRepeatNGramLogitsProcessor", - "ExponentialDecayLengthPenalty", - "LogitNormalization", "UnbatchedClassifierFreeGuidanceLogitsProcessor", ] _import_structure["stopping_criteria"] = [ @@ -99,6 +99,7 @@ _import_structure["tf_logits_process"] = [ "TFForcedBOSTokenLogitsProcessor", 
"TFForcedEOSTokenLogitsProcessor", + "TFForceTokensLogitsProcessor", "TFLogitsProcessor", "TFLogitsProcessorList", "TFLogitsWarper", @@ -106,12 +107,11 @@ "TFNoBadWordsLogitsProcessor", "TFNoRepeatNGramLogitsProcessor", "TFRepetitionPenaltyLogitsProcessor", + "TFSuppressTokensAtBeginLogitsProcessor", + "TFSuppressTokensLogitsProcessor", "TFTemperatureLogitsWarper", "TFTopKLogitsWarper", "TFTopPLogitsWarper", - "TFForceTokensLogitsProcessor", - "TFSuppressTokensAtBeginLogitsProcessor", - "TFSuppressTokensLogitsProcessor", ] _import_structure["tf_utils"] = [ "TFGenerationMixin", @@ -137,13 +137,17 @@ _import_structure["flax_logits_process"] = [ "FlaxForcedBOSTokenLogitsProcessor", "FlaxForcedEOSTokenLogitsProcessor", + "FlaxForceTokensLogitsProcessor", "FlaxLogitsProcessor", "FlaxLogitsProcessorList", "FlaxLogitsWarper", "FlaxMinLengthLogitsProcessor", "FlaxTemperatureLogitsWarper", + "FlaxSuppressTokensAtBeginLogitsProcessor", + "FlaxSuppressTokensLogitsProcessor", "FlaxTopKLogitsWarper", "FlaxTopPLogitsWarper", + "FlaxWhisperTimeStampLogitsProcessor", ] _import_structure["flax_utils"] = [ "FlaxGenerationMixin", @@ -261,13 +265,17 @@ from .flax_logits_process import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, + FlaxForceTokensLogitsProcessor, FlaxLogitsProcessor, FlaxLogitsProcessorList, FlaxLogitsWarper, FlaxMinLengthLogitsProcessor, + FlaxSuppressTokensAtBeginLogitsProcessor, + FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, + FlaxWhisperTimeStampLogitsProcessor, ) from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput else: diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 78be4ef747e9..0f6af902ec26 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -16,6 +16,13 @@ def __init__(self, *args, **kwargs): 
requires_backends(self, ["flax"]) +class FlaxForceTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxGenerationMixin(metaclass=DummyObject): _backends = ["flax"] @@ -51,6 +58,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxTemperatureLogitsWarper(metaclass=DummyObject): _backends = ["flax"] @@ -72,6 +93,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxPreTrainedModel(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5724e689f2fc..f0035f6ec78d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -121,6 +121,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class EncoderNoRepeatNGramLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EpsilonLogitsWarper(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class 
EtaLogitsWarper(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ExponentialDecayLengthPenalty(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject): _backends = ["torch"] @@ -156,6 +191,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class LogitNormalization(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LogitsProcessor(metaclass=DummyObject): _backends = ["torch"] @@ -289,6 +331,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + def top_k_top_p_filtering(*args, **kwargs): requires_backends(top_k_top_p_filtering, ["torch"]) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 46cde8ffbef4..9b1aae449326 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -30,6 +30,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFForceTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFGenerationMixin(metaclass=DummyObject): _backends = ["tf"] @@ -86,6 +93,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSuppressTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def 
__init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFTemperatureLogitsWarper(metaclass=DummyObject): _backends = ["tf"] From 14d9750e92f5c62762086c4f63335dd6b320716a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Aug 2023 11:44:25 +0000 Subject: [PATCH 3/7] more missing imports --- src/transformers/__init__.py | 2 ++ src/transformers/generation/__init__.py | 2 ++ src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 3 files changed, 11 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f19bc99b4f0b..30bdd11211a6 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1037,6 +1037,7 @@ "SequenceBiasLogitsProcessor", "StoppingCriteria", "StoppingCriteriaList", + "SuppressTokensAtBeginLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", @@ -5028,6 +5029,7 @@ SequenceBiasLogitsProcessor, StoppingCriteria, StoppingCriteriaList, + SuppressTokensAtBeginLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 2bfdd529f932..67fe1b4a9970 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -61,6 +61,7 @@ "PrefixConstrainedLogitsProcessor", "RepetitionPenaltyLogitsProcessor", "SequenceBiasLogitsProcessor", + "SuppressTokensAtBeginLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", @@ -189,6 +190,7 @@ PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, SequenceBiasLogitsProcessor, + SuppressTokensAtBeginLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index f0035f6ec78d..fcedc449d90a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -303,6 
+303,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class TemperatureLogitsWarper(metaclass=DummyObject): _backends = ["torch"] From 6e04a9d2bb77c67b22c01533fbd5627bc57fbb03 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Aug 2023 12:12:42 +0000 Subject: [PATCH 4/7] more missing imports --- src/transformers/__init__.py | 10 +++++++ src/transformers/generation/__init__.py | 10 +++++++ src/transformers/utils/dummy_pt_objects.py | 35 ++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 30bdd11211a6..2a05f7676508 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1005,8 +1005,10 @@ _import_structure["deepspeed"] = [] _import_structure["generation"].extend( [ + "AlternatingCodebooksLogitsProcessor", "BeamScorer", "BeamSearchScorer", + "ClassifierFreeGuidanceLogitsProcessor", "ConstrainedBeamSearchScorer", "Constraint", "ConstraintListState", @@ -1018,6 +1020,7 @@ "ExponentialDecayLengthPenalty", "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", + "ForceTokensLogitsProcessor", "GenerationMixin", "HammingDiversityLogitsProcessor", "InfNanRemoveLogitsProcessor", @@ -1038,11 +1041,13 @@ "StoppingCriteria", "StoppingCriteriaList", "SuppressTokensAtBeginLogitsProcessor", + "SuppressTokensLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", "TypicalLogitsWarper", "UnbatchedClassifierFreeGuidanceLogitsProcessor", + "WhisperTimeStampLogitsProcessor", "top_k_top_p_filtering", ] ) @@ -4997,8 +5002,10 @@ TextDatasetForNextSentencePrediction, ) from .generation import ( + AlternatingCodebooksLogitsProcessor, BeamScorer, BeamSearchScorer, + ClassifierFreeGuidanceLogitsProcessor, ConstrainedBeamSearchScorer, 
Constraint, ConstraintListState, @@ -5010,6 +5017,7 @@ ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, + ForceTokensLogitsProcessor, GenerationMixin, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, @@ -5030,11 +5038,13 @@ StoppingCriteria, StoppingCriteriaList, SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, + WhisperTimeStampLogitsProcessor, top_k_top_p_filtering, ) from .modeling_utils import PreTrainedModel diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 67fe1b4a9970..e2a7f1d2bc82 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -41,6 +41,8 @@ "ConstrainedBeamSearchScorer", ] _import_structure["logits_process"] = [ + "AlternatingCodebooksLogitsProcessor", + "ClassifierFreeGuidanceLogitsProcessor", "EncoderNoRepeatNGramLogitsProcessor", "EncoderRepetitionPenaltyLogitsProcessor", "EpsilonLogitsWarper", @@ -48,6 +50,7 @@ "ExponentialDecayLengthPenalty", "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", + "ForceTokensLogitsProcessor", "HammingDiversityLogitsProcessor", "InfNanRemoveLogitsProcessor", "LogitNormalization", @@ -61,12 +64,14 @@ "PrefixConstrainedLogitsProcessor", "RepetitionPenaltyLogitsProcessor", "SequenceBiasLogitsProcessor", + "SuppressTokensLogitsProcessor", "SuppressTokensAtBeginLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", "TypicalLogitsWarper", "UnbatchedClassifierFreeGuidanceLogitsProcessor", + "WhisperTimeStampLogitsProcessor", ] _import_structure["stopping_criteria"] = [ "MaxNewTokensCriteria", @@ -170,6 +175,8 @@ from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamHypotheses, BeamScorer, 
BeamSearchScorer, ConstrainedBeamSearchScorer from .logits_process import ( + AlternatingCodebooksLogitsProcessor, + ClassifierFreeGuidanceLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, @@ -177,6 +184,7 @@ ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, + ForceTokensLogitsProcessor, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, LogitNormalization, @@ -191,11 +199,13 @@ RepetitionPenaltyLogitsProcessor, SequenceBiasLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, + WhisperTimeStampLogitsProcessor, ) from .stopping_criteria import ( MaxLengthCriteria, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index fcedc449d90a..c1cdc3955e97 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -79,6 +79,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class BeamScorer(metaclass=DummyObject): _backends = ["torch"] @@ -93,6 +100,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ConstrainedBeamSearchScorer(metaclass=DummyObject): _backends = ["torch"] @@ -170,6 +184,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ForceTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + 
requires_backends(self, ["torch"]) + + class GenerationMixin(metaclass=DummyObject): _backends = ["torch"] @@ -310,6 +331,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SuppressTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class TemperatureLogitsWarper(metaclass=DummyObject): _backends = ["torch"] @@ -345,6 +373,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class WhisperTimeStampLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + def top_k_top_p_filtering(*args, **kwargs): requires_backends(top_k_top_p_filtering, ["torch"]) From 807fa5d75a4b20502c64339c09ba1fb62e9871c7 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Aug 2023 17:46:31 +0000 Subject: [PATCH 5/7] tmp commit --- docs/source/en/internal/generation_utils.md | 35 +++++++++++++-------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index e159923a171f..d8dc555822a9 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -75,39 +75,44 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`. We document here all output types. 
-### GreedySearchOutput +### PyTorch [[autodoc]] generation.GreedySearchDecoderOnlyOutput [[autodoc]] generation.GreedySearchEncoderDecoderOutput -[[autodoc]] generation.FlaxGreedySearchOutput - -### SampleOutput - [[autodoc]] generation.SampleDecoderOnlyOutput [[autodoc]] generation.SampleEncoderDecoderOutput -[[autodoc]] generation.FlaxSampleOutput - -### BeamSearchOutput - [[autodoc]] generation.BeamSearchDecoderOnlyOutput [[autodoc]] generation.BeamSearchEncoderDecoderOutput -### BeamSampleOutput - [[autodoc]] generation.BeamSampleDecoderOnlyOutput [[autodoc]] generation.BeamSampleEncoderDecoderOutput + +### TensorFlow + + + +### FLAX + +[[autodoc]] generation.FlaxSampleOutput + +[[autodoc]] generation.FlaxGreedySearchOutput + + + ## LogitsProcessor A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for generation. +### PyTorch + [[autodoc]] LogitsProcessor - __call__ @@ -201,6 +206,8 @@ generation. [[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor - __call__ +### TensorFlow + [[autodoc]] TFLogitsProcessor - __call__ @@ -246,6 +253,8 @@ generation. [[autodoc]] TFForceTokensLogitsProcessor - __call__ +### FLAX + [[autodoc]] FlaxLogitsProcessor - __call__ @@ -287,7 +296,7 @@ generation. ## StoppingCriteria -A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). +A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusively available to our PyTorch implementations. [[autodoc]] StoppingCriteria - __call__ @@ -303,7 +312,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than ## Constraints -A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. +A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations. 
[[autodoc]] Constraint From 697f8b89d6cc7a98e8ae21a6bb19807f5f065b5d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Aug 2023 18:01:24 +0000 Subject: [PATCH 6/7] sort docs by alphabetical order; Add missing output classes; Separate TOC by framework --- docs/source/en/internal/generation_utils.md | 137 +++++++++++--------- src/transformers/generation/__init__.py | 2 +- 2 files changed, 80 insertions(+), 59 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index d8dc555822a9..1e797ce10bef 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -77,26 +77,47 @@ We document here all output types. ### PyTorch +[[autodoc]] generation.GreedySearchEncoderDecoderOutput + [[autodoc]] generation.GreedySearchDecoderOnlyOutput -[[autodoc]] generation.GreedySearchEncoderDecoderOutput +[[autodoc]] generation.SampleEncoderDecoderOutput [[autodoc]] generation.SampleDecoderOnlyOutput -[[autodoc]] generation.SampleEncoderDecoderOutput +[[autodoc]] generation.BeamSearchEncoderDecoderOutput [[autodoc]] generation.BeamSearchDecoderOnlyOutput -[[autodoc]] generation.BeamSearchEncoderDecoderOutput +[[autodoc]] generation.BeamSampleEncoderDecoderOutput [[autodoc]] generation.BeamSampleDecoderOnlyOutput -[[autodoc]] generation.BeamSampleEncoderDecoderOutput +[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput +[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput ### TensorFlow +[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput + +[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput + +[[autodoc]] generation.TFSampleEncoderDecoderOutput + +[[autodoc]] generation.TFSampleDecoderOnlyOutput + +[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput + +[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput + +[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput + +[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput +[[autodoc]] 
generation.TFContrastiveSearchEncoderDecoderOutput + +[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput ### FLAX @@ -104,7 +125,7 @@ We document here all output types. [[autodoc]] generation.FlaxGreedySearchOutput - +[[autodoc]] generation.FlaxBeamSearchSearchOutput ## LogitsProcessor @@ -113,76 +134,73 @@ generation. ### PyTorch -[[autodoc]] LogitsProcessor - - __call__ - -[[autodoc]] LogitsProcessorList +[[autodoc]] AlternatingCodebooksLogitsProcessor - __call__ -[[autodoc]] LogitsWarper +[[autodoc]] ClassifierFreeGuidanceLogitsProcessor - __call__ -[[autodoc]] MinLengthLogitsProcessor +[[autodoc]] EncoderNoRepeatNGramLogitsProcessor - __call__ -[[autodoc]] MinNewTokensLengthLogitsProcessor +[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor - __call__ -[[autodoc]] TemperatureLogitsWarper +[[autodoc]] EpsilonLogitsWarper - __call__ -[[autodoc]] RepetitionPenaltyLogitsProcessor +[[autodoc]] EtaLogitsWarper - __call__ -[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor +[[autodoc]] ExponentialDecayLengthPenalty - __call__ -[[autodoc]] TopPLogitsWarper +[[autodoc]] ForcedBOSTokenLogitsProcessor - __call__ -[[autodoc]] TopKLogitsWarper +[[autodoc]] ForcedEOSTokenLogitsProcessor - __call__ -[[autodoc]] TypicalLogitsWarper +[[autodoc]] ForceTokensLogitsProcessor - __call__ -[[autodoc]] EpsilonLogitsWarper +[[autodoc]] HammingDiversityLogitsProcessor - __call__ -[[autodoc]] EtaLogitsWarper +[[autodoc]] InfNanRemoveLogitsProcessor - __call__ -[[autodoc]] NoRepeatNGramLogitsProcessor +[[autodoc]] LogitNormalization - __call__ -[[autodoc]] EncoderNoRepeatNGramLogitsProcessor +[[autodoc]] LogitsProcessor - __call__ -[[autodoc]] SequenceBiasLogitsProcessor +[[autodoc]] LogitsProcessorList - __call__ -[[autodoc]] NoBadWordsLogitsProcessor +[[autodoc]] LogitsWarper - __call__ -[[autodoc]] PrefixConstrainedLogitsProcessor +[[autodoc]] MinLengthLogitsProcessor - __call__ -[[autodoc]] HammingDiversityLogitsProcessor +[[autodoc]] MinNewTokensLengthLogitsProcessor - 
__call__ -[[autodoc]] ForcedBOSTokenLogitsProcessor +[[autodoc]] NoBadWordsLogitsProcessor - __call__ -[[autodoc]] ForcedEOSTokenLogitsProcessor +[[autodoc]] NoRepeatNGramLogitsProcessor - __call__ -[[autodoc]] InfNanRemoveLogitsProcessor +[[autodoc]] PrefixConstrainedLogitsProcessor - __call__ -[[autodoc]] ExponentialDecayLengthPenalty +[[autodoc]] RepetitionPenaltyLogitsProcessor - __call__ -[[autodoc]] LogitNormalization +[[autodoc]] SequenceBiasLogitsProcessor - __call__ [[autodoc]] SuppressTokensAtBeginLogitsProcessor @@ -191,39 +209,42 @@ generation. [[autodoc]] SuppressTokensLogitsProcessor - __call__ -[[autodoc]] ForceTokensLogitsProcessor +[[autodoc]] TemperatureLogitsWarper - __call__ -[[autodoc]] WhisperTimeStampLogitsProcessor +[[autodoc]] TopKLogitsWarper - __call__ -[[autodoc]] ClassifierFreeGuidanceLogitsProcessor +[[autodoc]] TopPLogitsWarper - __call__ -[[autodoc]] AlternatingCodebooksLogitsProcessor +[[autodoc]] TypicalLogitsWarper - __call__ [[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor - __call__ +[[autodoc]] WhisperTimeStampLogitsProcessor + - __call__ + ### TensorFlow -[[autodoc]] TFLogitsProcessor +[[autodoc]] TFForcedBOSTokenLogitsProcessor - __call__ -[[autodoc]] TFLogitsProcessorList +[[autodoc]] TFForcedEOSTokenLogitsProcessor - __call__ -[[autodoc]] TFLogitsWarper +[[autodoc]] TFForceTokensLogitsProcessor - __call__ -[[autodoc]] TFTemperatureLogitsWarper +[[autodoc]] TFLogitsProcessor - __call__ -[[autodoc]] TFTopPLogitsWarper +[[autodoc]] TFLogitsProcessorList - __call__ -[[autodoc]] TFTopKLogitsWarper +[[autodoc]] TFLogitsWarper - __call__ [[autodoc]] TFMinLengthLogitsProcessor @@ -238,57 +259,57 @@ generation. 
[[autodoc]] TFRepetitionPenaltyLogitsProcessor - __call__ -[[autodoc]] TFForcedBOSTokenLogitsProcessor +[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor - __call__ -[[autodoc]] TFForcedEOSTokenLogitsProcessor +[[autodoc]] TFSuppressTokensLogitsProcessor - __call__ -[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor +[[autodoc]] TFTemperatureLogitsWarper - __call__ -[[autodoc]] TFSuppressTokensLogitsProcessor +[[autodoc]] TFTopKLogitsWarper - __call__ -[[autodoc]] TFForceTokensLogitsProcessor +[[autodoc]] TFTopPLogitsWarper - __call__ ### FLAX -[[autodoc]] FlaxLogitsProcessor +[[autodoc]] FlaxForcedBOSTokenLogitsProcessor - __call__ -[[autodoc]] FlaxLogitsProcessorList +[[autodoc]] FlaxForcedEOSTokenLogitsProcessor - __call__ -[[autodoc]] FlaxLogitsWarper +[[autodoc]] FlaxForceTokensLogitsProcessor - __call__ -[[autodoc]] FlaxTemperatureLogitsWarper +[[autodoc]] FlaxLogitsProcessor - __call__ -[[autodoc]] FlaxTopPLogitsWarper +[[autodoc]] FlaxLogitsProcessorList - __call__ -[[autodoc]] FlaxTopKLogitsWarper +[[autodoc]] FlaxLogitsWarper - __call__ -[[autodoc]] FlaxForcedBOSTokenLogitsProcessor +[[autodoc]] FlaxMinLengthLogitsProcessor - __call__ -[[autodoc]] FlaxForcedEOSTokenLogitsProcessor +[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor - __call__ -[[autodoc]] FlaxMinLengthLogitsProcessor +[[autodoc]] FlaxSuppressTokensLogitsProcessor - __call__ -[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor +[[autodoc]] FlaxTemperatureLogitsWarper - __call__ -[[autodoc]] FlaxSuppressTokensLogitsProcessor +[[autodoc]] FlaxTopKLogitsWarper - __call__ -[[autodoc]] FlaxForceTokensLogitsProcessor +[[autodoc]] FlaxTopPLogitsWarper - __call__ [[autodoc]] FlaxWhisperTimeStampLogitsProcessor diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index e2a7f1d2bc82..a46cb4fa910a 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -148,9 +148,9 @@ "FlaxLogitsProcessorList", 
"FlaxLogitsWarper", "FlaxMinLengthLogitsProcessor", - "FlaxTemperatureLogitsWarper", "FlaxSuppressTokensAtBeginLogitsProcessor", "FlaxSuppressTokensLogitsProcessor", + "FlaxTemperatureLogitsWarper", "FlaxTopKLogitsWarper", "FlaxTopPLogitsWarper", "FlaxWhisperTimeStampLogitsProcessor", From e7c75a048b82461e5e092eb88c983c9983a4d0d5 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Aug 2023 18:20:49 +0000 Subject: [PATCH 7/7] typo --- docs/source/en/internal/generation_utils.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 1e797ce10bef..906ee4ea620b 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -125,7 +125,7 @@ We document here all output types. [[autodoc]] generation.FlaxGreedySearchOutput -[[autodoc]] generation.FlaxBeamSearchSearchOutput +[[autodoc]] generation.FlaxBeamSearchOutput ## LogitsProcessor