
Commit 80d712f

Adding new argument max_new_tokens for generate. (huggingface#11476)
* Adding new argument `max_new_tokens` for `generate`. This is a proposal to add a new argument `max_new_tokens` to `generate`. It includes a `MaxNewTokensCriteria` that enables callers who don't know the token length ahead of time (like pipeline callers) to more easily manage the length of their generated output.
* Adding a test for the user warning when both `max_length` and `max_new_tokens` are used together.
* Removed redundant `no_grad`.
1 parent 2dd6fb2 commit 80d712f

File tree: 4 files changed (+86, -5 lines)
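For context, here is a minimal caller-side sketch of the new argument (not part of the commit; the `gpt2` checkpoint and prompt are illustrative, any generative model works):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

# Bound only the continuation: the caller does not need to know the prompt
# length in advance, which is the pipeline use case this commit targets.
output_ids = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```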

src/transformers/generation_stopping_criteria.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -57,6 +57,29 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return input_ids.shape[-1] >= self.max_length


+class MaxNewTokensCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the generated number of tokens exceeds :obj:`max_new_tokens`.
+    Keep in mind that for decoder-only transformers this will **not** include the initial prompt tokens. This is
+    very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens.
+
+    Args:
+        start_length (:obj:`int`):
+            The number of initial tokens.
+        max_new_tokens (:obj:`int`):
+            The maximum number of tokens to generate.
+    """
+
+    def __init__(self, start_length: int, max_new_tokens: int):
+        self.start_length = start_length
+        self.max_new_tokens = max_new_tokens
+        self.max_length = start_length + max_new_tokens
+
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        return input_ids.shape[-1] >= self.max_length
+
+
 class MaxTimeCriteria(StoppingCriteria):
     """
     This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the

@@ -89,6 +112,8 @@ def max_length(self) -> Optional[int]:
         for stopping_criterium in self:
             if isinstance(stopping_criterium, MaxLengthCriteria):
                 return stopping_criterium.max_length
+            elif isinstance(stopping_criterium, MaxNewTokensCriteria):
+                return stopping_criterium.max_length
         return None
```
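To make the criterion's arithmetic concrete, a small standalone check with dummy tensors (a sketch, not part of the commit):

```python
import torch

from transformers.generation_stopping_criteria import MaxNewTokensCriteria

# A 4-token prompt plus at most 6 new tokens: stop once the sequence reaches 10.
criteria = MaxNewTokensCriteria(start_length=4, max_new_tokens=6)

scores = torch.zeros((1, 100))  # dummy logits; the criterion ignores them

input_ids = torch.ones((1, 9), dtype=torch.long)
print(criteria(input_ids, scores))  # False: 9 < 4 + 6

input_ids = torch.ones((1, 10), dtype=torch.long)
print(criteria(input_ids, scores))  # True: 10 >= 4 + 6
```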

src/transformers/generation_utils.py

Lines changed: 22 additions & 5 deletions
```diff
@@ -42,6 +42,7 @@
 )
 from .generation_stopping_criteria import (
     MaxLengthCriteria,
+    MaxNewTokensCriteria,
     MaxTimeCriteria,
     StoppingCriteriaList,
     validate_stopping_criteria,

@@ -628,15 +629,15 @@ def _get_logits_processor(
         return processors

     def _get_stopping_criteria(
-        self,
-        max_length: Optional[int],
-        max_time: Optional[float],
+        self, max_length: Optional[int], max_time: Optional[float], max_new_tokens: Optional[int], start_length: int
     ) -> StoppingCriteriaList:
         stopping_criteria = StoppingCriteriaList()
         if max_length is not None:
             stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
         if max_time is not None:
             stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
+        if max_new_tokens is not None:
+            stopping_criteria.append(MaxNewTokensCriteria(start_length=start_length, max_new_tokens=max_new_tokens))
         return stopping_criteria

     @torch.no_grad()

@@ -661,6 +662,7 @@ def generate(
         encoder_no_repeat_ngram_size: Optional[int] = None,
         num_return_sequences: Optional[int] = None,
         max_time: Optional[float] = None,
+        max_new_tokens: Optional[int] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
         num_beam_groups: Optional[int] = None,

@@ -692,8 +694,11 @@
             input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
                 :obj:`torch.LongTensor` of shape :obj:`(1,)`.
-            max_length (:obj:`int`, `optional`, defaults to 20):
+            max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
                 The maximum length of the sequence to be generated.
+            max_new_tokens (:obj:`int`, `optional`, defaults to None):
+                The maximum number of tokens to generate, ignoring the current number of tokens. Use either
+                :obj:`max_new_tokens` or :obj:`max_length`, but not both; they serve the same purpose.
             min_length (:obj:`int`, `optional`, defaults to 10):
                 The minimum length of the sequence to be generated.
             do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):

@@ -861,6 +866,15 @@ def generate(
         """

         # set init values
+        if max_length is None and max_new_tokens is None:
+            # Both are None, default
+            max_length = self.config.max_length
+        elif max_length is not None and max_new_tokens is not None:
+            # Both are set, this is odd, raise a warning
+            warnings.warn(
+                "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning
+            )
+
         max_length = max_length if max_length is not None else self.config.max_length
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups

@@ -960,7 +974,10 @@
             remove_invalid_values=remove_invalid_values,
         )

-        stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
+        cur_len = input_ids.shape[-1]
+        stopping_criteria = self._get_stopping_criteria(
+            max_length=max_length, max_time=max_time, max_new_tokens=max_new_tokens, start_length=cur_len
+        )

         if is_greedy_gen_mode:
             if num_return_sequences > 1:
```
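The warning path can be exercised from the caller's side; here is a sketch mirroring the test added in `tests/test_generation_utils.py` (the tiny BART checkpoint is the one used there):

```python
import warnings

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")

input_ids = tokenizer("Some prompt text.", return_tensors="pt").input_ids

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Passing both length controls triggers the UserWarning added above.
    model.generate(input_ids, max_length=20, max_new_tokens=10)

assert any(issubclass(w.category, UserWarning) for w in caught)
```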

tests/test_generation_stopping_criteria.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@

 from transformers.generation_stopping_criteria import (
     MaxLengthCriteria,
+    MaxNewTokensCriteria,
     MaxTimeCriteria,
     StoppingCriteriaList,
     validate_stopping_criteria,

@@ -58,6 +59,21 @@ def test_max_length_criteria(self):
         input_ids, scores = self._get_tensors(10)
         self.assertTrue(criteria(input_ids, scores))

+    def test_max_new_tokens_criteria(self):
+        criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
+
+        input_ids, scores = self._get_tensors(5)
+        self.assertFalse(criteria(input_ids, scores))
+
+        input_ids, scores = self._get_tensors(9)
+        self.assertFalse(criteria(input_ids, scores))
+
+        input_ids, scores = self._get_tensors(10)
+        self.assertTrue(criteria(input_ids, scores))
+
+        criteria_list = StoppingCriteriaList([criteria])
+        self.assertEqual(criteria_list.max_length, 10)
+
     def test_max_time_criteria(self):
         input_ids, scores = self._get_tensors(5)
```
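Note the last two assertions: thanks to the `max_length` property change above, a criteria list now reports an effective maximum length through the new criterion as well. A small sketch (not part of the commit):

```python
from transformers.generation_stopping_criteria import (
    MaxNewTokensCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

criteria_list = StoppingCriteriaList(
    [
        MaxTimeCriteria(max_time=5.0),  # contributes no length bound
        MaxNewTokensCriteria(start_length=5, max_new_tokens=5),
    ]
)

# Resolved through the new `elif` branch: start_length + max_new_tokens.
print(criteria_list.max_length)  # 10
```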

tests/test_generation_utils.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -1615,3 +1615,26 @@ def test_beam_search_warning_if_max_length_is_passed(self):

         # BeamSearchScorer max_length should not influence "real" max_length
         self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
+
+    def test_max_new_tokens(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+        self.assertEqual(list(input_ids.shape), [1, 15])
+
+        # Encoder decoder call
+        max_new_tokens = 3
+        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
+        # 1 BOS + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 4])
+
+        # Decoder only call
+        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
+        # 15 + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 18])
+
+        # max_new_tokens and max_length serve the same purpose and should not be used together.
+        with self.assertWarns(UserWarning):
+            outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
```
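The shapes asserted above follow from how `max_new_tokens` interacts with each architecture; the arithmetic, spelled out (a sketch, not library code):

```python
max_new_tokens = 3

# Encoder-decoder call: the decoder starts from a single start token,
# so the output holds 1 (BOS) + max_new_tokens tokens.
encoder_decoder_output_len = 1 + max_new_tokens
assert encoder_decoder_output_len == 4

# Decoder-only call: the 15 prompt tokens remain part of the output,
# so the output holds 15 + max_new_tokens tokens.
decoder_only_output_len = 15 + max_new_tokens
assert decoder_only_output_len == 18
```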
