4242)
4343from .generation_stopping_criteria import (
4444 MaxLengthCriteria ,
45+ MaxNewTokensCriteria ,
4546 MaxTimeCriteria ,
4647 StoppingCriteriaList ,
4748 validate_stopping_criteria ,
@@ -628,15 +629,15 @@ def _get_logits_processor(
628629 return processors
629630
630631 def _get_stopping_criteria (
631- self ,
632- max_length : Optional [int ],
633- max_time : Optional [float ],
632+ self , max_length : Optional [int ], max_time : Optional [float ], max_new_tokens : Optional [int ], start_length : int
634633 ) -> StoppingCriteriaList :
635634 stopping_criteria = StoppingCriteriaList ()
636635 if max_length is not None :
637636 stopping_criteria .append (MaxLengthCriteria (max_length = max_length ))
638637 if max_time is not None :
639638 stopping_criteria .append (MaxTimeCriteria (max_time = max_time ))
639+ if max_new_tokens is not None :
640+ stopping_criteria .append (MaxNewTokensCriteria (start_length = start_length , max_new_tokens = max_new_tokens ))
640641 return stopping_criteria
641642
642643 @torch .no_grad ()
@@ -661,6 +662,7 @@ def generate(
661662 encoder_no_repeat_ngram_size : Optional [int ] = None ,
662663 num_return_sequences : Optional [int ] = None ,
663664 max_time : Optional [float ] = None ,
665+ max_new_tokens : Optional [int ] = None ,
664666 decoder_start_token_id : Optional [int ] = None ,
665667 use_cache : Optional [bool ] = None ,
666668 num_beam_groups : Optional [int ] = None ,
@@ -692,8 +694,11 @@ def generate(
692694 input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693695 The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
694696 :obj:`torch.LongTensor` of shape :obj:`(1,)`.
695- max_length (:obj:`int`, `optional`, defaults to 20 ):
697+ max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length` ):
696698 The maximum length of the sequence to be generated.
699+ max_new_tokens (:obj:`int`, `optional`, defaults to :obj:`None`):
700+ The maximum number of tokens to generate, ignoring the number of tokens in the prompt. Use either
701+ :obj:`max_new_tokens` or :obj:`max_length`, but not both, as they serve the same purpose.
697702 min_length (:obj:`int`, `optional`, defaults to 10):
698703 The minimum length of the sequence to be generated.
699704 do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -861,6 +866,15 @@ def generate(
861866 """
862867
863868 # set init values
869+ if max_length is None and max_new_tokens is None :
870+ # Both are None, default
871+ max_length = self .config .max_length
872+ elif max_length is not None and max_new_tokens is not None :
873+ # Both are set, this is odd, raise a warning
874+ warnings .warn (
875+ "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose." , UserWarning
876+ )
877+
864878 max_length = max_length if max_length is not None else self .config .max_length
865879 num_beams = num_beams if num_beams is not None else self .config .num_beams
866880 num_beam_groups = num_beam_groups if num_beam_groups is not None else self .config .num_beam_groups
@@ -960,7 +974,10 @@ def generate(
960974 remove_invalid_values = remove_invalid_values ,
961975 )
962976
963- stopping_criteria = self ._get_stopping_criteria (max_length = max_length , max_time = max_time )
977+ cur_len = input_ids .shape [- 1 ]
978+ stopping_criteria = self ._get_stopping_criteria (
979+ max_length = max_length , max_time = max_time , max_new_tokens = max_new_tokens , start_length = cur_len
980+ )
964981
965982 if is_greedy_gen_mode :
966983 if num_return_sequences > 1 :
0 commit comments