diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index ba5d0f0005a6..9689ca2b5203 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -129,9 +129,9 @@ def __init__( value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value) ) - # Remove potential default "num_logits_to_keep" key - if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep(): - del assistant_kwargs["num_logits_to_keep"] + # Remove potential default "logits_to_keep" key + if "logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_logits_to_keep(): + del assistant_kwargs["logits_to_keep"] if "assistant_encoder_outputs" in model_kwargs: assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 655a388cb70d..1c77dbdbb90c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1813,12 +1813,12 @@ def _prepare_cache_for_generation( else EncoderDecoderCache(DynamicCache(), DynamicCache()) ) - def _supports_num_logits_to_keep(self) -> bool: + def _supports_logits_to_keep(self) -> bool: """ - Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() + Return True if the current model supports the keyword argument `logits_to_keep` in forward() to save memory. Checking it in this way allows to avoid using a new model attribute. """ - return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) + return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) def _prepare_special_tokens( self, @@ -2099,11 +2099,11 @@ def generate( input_ids_length=input_ids_length, ) - # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding # dynamically overrides this value as it can need more than the last token logits - if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: - model_kwargs["num_logits_to_keep"] = 1 + if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs: + model_kwargs["logits_to_keep"] = 1 self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -4269,8 +4269,8 @@ def _assisted_decoding( ) model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) - if "num_logits_to_keep" in model_inputs: - model_inputs["num_logits_to_keep"] = candidate_length + 1 + if "logits_to_keep" in model_inputs: + model_inputs["logits_to_keep"] = candidate_length + 1 # 2.2. Run a forward pass on the candidate sequence # prepare variable output controls (note: some models won't accept all output controls) @@ -4608,7 +4608,7 @@ def _split_model_inputs( # ModelOutput object. # bool should not be split but replicated for each split bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"] + keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"] non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] num_hidden_layers = config.get_text_config().num_hidden_layers @@ -4628,10 +4628,10 @@ def _split_model_inputs( data_split_list = [ {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) ] - # num_logits_to_keep should be replicated for each split, similar to bool values - if "num_logits_to_keep" in model_input: + # logits_to_keep should be replicated for each split, similar to bool values + if "logits_to_keep" in model_input: data_split_list = [ - {**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list + {**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list ] # Convert each dictionary in the list to an object of the inferred class diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3a9b044c1168..61901c0eda0f 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1292,6 +1292,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # `config.base_model_tp_plan` during `post_init`. _tp_plan = None + # This flag signal that the model can be used as an efficient backend in TGI and vLLM + # In practice, it means that they support attention interface functions, fully pass the kwargs + # through all modules up to the Attention layer, and can slice logits with Tensor + _supports_attention_backend = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ @@ -5187,6 +5192,10 @@ def get_compiled_call(self, compile_config: CompileConfig): self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict()) return self._compiled_call + @classmethod + def is_backend_compatible(cls): + return cls._supports_attention_backend + PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) if PreTrainedModel.push_to_hub.__doc__ is not None: diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 0b330b4aeeda..414301673552 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -37,6 +37,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_available from ..auto import AutoModel, AutoModelForCausalLM from .configuration_aria import AriaConfig, AriaTextConfig @@ -708,6 +709,7 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = False def _init_weights(self, module): std = self.config.initializer_range @@ -1168,6 +1170,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1183,7 +1186,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1193,10 +1196,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1239,7 +1244,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1324,8 +1330,9 @@ class AriaCausalLMOutputWithPast(ModelOutput): Whether to output hidden states. return_dict (`bool`, *optional*): Whether to return a `ModelOutput` object. - num_logits_to_keep (`int`, *optional*, defaults to 0): - Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`. + Otherwise, slice according to the 1D tensor in the sequence length dimension cache_position (`torch.LongTensor`, *optional*): Cache positions. **loss_kwargs: @@ -1426,6 +1433,7 @@ def get_image_features( image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( @@ -1442,7 +1450,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, ) -> Union[Tuple, AriaCausalLMOutputWithPast]: @@ -1552,7 +1560,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -1584,7 +1592,7 @@ def prepare_inputs_for_generation( pixel_mask=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): model_inputs = self.language_model.prepare_inputs_for_generation( @@ -1593,7 +1601,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 295e2dcb7465..5c40473a18f7 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -45,6 +45,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_available from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer from ..llama.configuration_llama import LlamaConfig @@ -1222,6 +1223,8 @@ def _init_weights(self, module): class AriaPreTrainedModel(LlamaPreTrainedModel): + _supports_attention_backend = False + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): @@ -1301,8 +1304,9 @@ class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): Whether to output hidden states. return_dict (`bool`, *optional*): Whether to return a `ModelOutput` object. - num_logits_to_keep (`int`, *optional*, defaults to 0): - Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`. + Otherwise, slice according to the 1D tensor in the sequence length dimension cache_position (`torch.LongTensor`, *optional*): Cache positions. **loss_kwargs: @@ -1403,6 +1407,7 @@ def get_image_features( image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( @@ -1419,7 +1424,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, ) -> Union[Tuple, AriaCausalLMOutputWithPast]: @@ -1529,7 +1534,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -1561,7 +1566,7 @@ def prepare_inputs_for_generation( pixel_mask=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): model_inputs = self.language_model.prepare_inputs_for_generation( @@ -1570,7 +1575,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 20a3247be2bb..edfc162a032a 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -41,6 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_2_ssm_available, @@ -1466,6 +1467,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1481,7 +1483,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1491,10 +1493,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1537,7 +1541,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1602,7 +1607,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 7fb35f48fb3b..93fb274e4d4d 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -54,6 +54,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, @@ -1182,6 +1183,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): class BambaForCausalLM(LlamaForCausalLM): + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1197,7 +1199,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1207,10 +1209,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1242,7 +1246,7 @@ def forward( output_hidden_states, return_dict, cache_position, - num_logits_to_keep, + logits_to_keep, **kwargs, ) @@ -1293,7 +1297,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9c7207adfc1f..7337ae6acf49 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -48,6 +48,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_cohere import CohereConfig @@ -421,6 +422,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -808,6 +810,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -823,7 +826,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -833,10 +836,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -879,7 +884,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits * self.logit_scale # main diff from Llama loss = None diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 6ea8fd6c8356..17eb3f6a3434 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -317,7 +317,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -327,10 +327,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -373,7 +375,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits * self.logit_scale # main diff from Llama loss = None diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 0b38c89d75a5..e4bb8bb687e1 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -39,6 +39,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_cohere2 import Cohere2Config @@ -421,6 +422,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -780,6 +782,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -795,7 +798,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -805,10 +808,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -851,7 +856,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits * self.logit_scale # main diff from Llama loss = None @@ -879,7 +885,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -934,8 +940,8 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 78419e78c08b..9c6872129c49 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -545,7 +545,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -600,8 +600,8 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index a2373d345412..5ad827689b41 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -35,6 +35,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_dbrx import DbrxConfig @@ -1257,6 +1258,7 @@ def set_decoder(self, decoder: DbrxModel): def get_decoder(self) -> DbrxModel: return self.transformer + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1273,7 +1275,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r"""Forward function for causal language modeling. @@ -1283,10 +1285,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1333,7 +1337,8 @@ def forward( hidden_states = outputs[0] # No upscaling to float was ever done for Dbrx - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index bba4e646599f..c262340aacf9 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -51,6 +51,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_diffllama import DiffLlamaConfig @@ -599,6 +600,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = False def _init_weights(self, module): std = self.config.initializer_range @@ -1045,6 +1047,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1060,7 +1063,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1070,10 +1073,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1116,7 +1121,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 2c8c84670652..c6bdf18093d4 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -432,6 +432,7 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: int): class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): _supports_flex_attn = False + _supports_attention_backend = False class DiffLlamaModel(LlamaModel): diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index b42a222f6ce9..6944f91b9758 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -44,6 +44,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig @@ -1626,6 +1627,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") def forward( @@ -1641,7 +1643,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1650,10 +1652,13 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1696,7 +1701,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1865,7 +1871,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1873,10 +1879,13 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1949,7 +1958,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) return outputs diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index aacf52fe31c6..01d09b703d8e 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -36,6 +36,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..chameleon.modeling_chameleon import ( ChameleonPreTrainedModel, ChameleonVQVAEEncoderConvDownsample, @@ -1071,6 +1072,7 @@ def __init__(self, config): super().__init__(config) self.model = Emu3TextModel(config) + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") def forward(**super_kwargs): @@ -1080,10 +1082,13 @@ def forward(**super_kwargs): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1177,7 +1182,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1185,10 +1190,13 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1261,7 +1269,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) return outputs diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index c0fad1ab66d5..f499801d2170 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -46,6 +46,7 @@ is_flash_attn_greater_or_equal_2_10, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_falcon import FalconConfig @@ -1176,6 +1177,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings: torch.Tensor): self.lm_head = new_embeddings + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1196,7 +1198,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1204,10 +1206,12 @@ def forward( `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1227,7 +1231,8 @@ def forward( ) hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + lm_logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 66e975edaa53..caaf2c60f519 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -46,6 +46,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_gemma import GemmaConfig @@ -387,6 +388,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -777,6 +779,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -792,7 +795,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -802,10 +805,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -848,7 +853,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 29b6f8a19461..9c015d37c2f4 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -474,10 +474,12 @@ def forward(**super_kwargs): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e64559b26650..c065136a2871 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -44,6 +44,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_gemma2 import Gemma2Config @@ -417,6 +418,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -782,6 +784,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -797,7 +800,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -807,10 +810,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -857,7 +862,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) @@ -888,7 +894,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -943,8 +949,8 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 5f21fc6bfffd..4e9b6ea95dae 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -540,7 +540,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -585,7 +585,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) @@ -616,7 +617,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten: has a special cache type, `HybridCache` @@ -671,8 +672,8 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 3e5107c561df..a3461ffd71cb 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -46,6 +46,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_glm import GlmConfig @@ -402,6 +403,7 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -787,6 +789,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -802,7 +805,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -812,10 +815,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -858,7 +863,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 3c887d3a1b91..4549cdd5d70b 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -40,6 +40,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_granite import GraniteConfig @@ -402,6 +403,7 @@ class GranitePreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -790,6 +792,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GRANITE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -805,7 +808,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -815,10 +818,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -861,7 +866,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits / self.config.logits_scaling # main diff with Llama loss = None diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 698280085f18..f23ae4a673c3 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -245,7 +245,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -271,7 +271,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits / self.config.logits_scaling # main diff with Llama loss = None diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 7eed89b4af33..71518c4a9aa8 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -47,6 +47,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_helium import HeliumConfig @@ -389,6 +390,7 @@ class HeliumPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -774,6 +776,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(HELIUM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -789,7 +792,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -799,10 +802,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -845,7 +850,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 4e819811a984..3aaf46d63df2 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -37,6 +37,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig @@ -1508,6 +1509,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1525,7 +1527,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: r""" Args: @@ -1535,10 +1537,12 @@ def forward( Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1604,7 +1608,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1648,7 +1653,7 @@ def prepare_inputs_for_generation( pixel_values=None, pixel_attention_mask=None, image_hidden_states=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take @@ -1677,8 +1682,8 @@ def prepare_inputs_for_generation( # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep if image_hidden_states is not None: pixel_values = None diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 31cf1a2e8f11..e4cc8bda569f 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -1242,7 +1242,7 @@ def prepare_inputs_for_generation( pixel_values=None, pixel_attention_mask=None, image_hidden_states=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take @@ -1271,8 +1271,8 @@ def prepare_inputs_for_generation( # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep if image_hidden_states is not None: pixel_values = None diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fd6b1bae31b1..24aeb9890b9f 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -45,6 +45,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, @@ -1433,9 +1434,9 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy def forward( self, input_ids: torch.LongTensor = None, @@ -1450,7 +1451,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[Union[int, None]] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1460,10 +1461,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1510,10 +1513,8 @@ def forward( ) hidden_states = outputs[0] - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states) - else: - logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1595,7 +1596,7 @@ def prepare_inputs_for_generation( "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 433ca61fabec..fca47eb3fa0d 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -42,6 +42,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_jetmoe import JetMoeConfig @@ -1274,6 +1275,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1290,7 +1292,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1299,10 +1301,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: """ @@ -1329,7 +1333,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8cbb12628c0a..361ae15c3127 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -47,6 +47,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_llama import LlamaConfig @@ -391,6 +392,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -776,6 +778,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -791,7 +794,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -801,10 +804,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -847,7 +852,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 93d7465291cb..fcf016f28f81 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -31,6 +31,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava import LlavaConfig @@ -380,6 +381,7 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in return final_embedding, final_attention_mask, final_labels, position_ids + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -398,7 +400,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" Args: @@ -407,10 +409,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -490,7 +494,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -534,7 +538,7 @@ def prepare_inputs_for_generation( pixel_values=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -545,7 +549,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 51df47233b26..8bff9dc90061 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -34,6 +34,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next import LlavaNextConfig @@ -752,6 +753,7 @@ def get_image_features( image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -771,7 +773,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]: r""" Args: @@ -780,10 +782,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -871,7 +875,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -916,7 +920,7 @@ def prepare_inputs_for_generation( image_sizes=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -927,7 +931,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 257c81aa8fe4..c82d52bfdaa9 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -33,6 +33,7 @@ from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next_video import LlavaNextVideoConfig @@ -787,6 +788,7 @@ def get_image_features( image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -807,7 +809,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -819,10 +821,12 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -967,7 +971,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -1014,7 +1018,7 @@ def prepare_inputs_for_generation( image_sizes=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- extra custom processing @@ -1025,7 +1029,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 89975a745b79..580f890b4266 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -335,7 +335,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -347,10 +347,12 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -495,7 +497,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -542,7 +544,7 @@ def prepare_inputs_for_generation( image_sizes=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- extra custom processing @@ -553,7 +555,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 5c5471479e86..f1cf7a6c2dca 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -32,6 +32,7 @@ add_start_docstrings, logging, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_onevision import LlavaOnevisionConfig @@ -568,6 +569,7 @@ def get_video_features( return video_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) def forward( self, @@ -589,7 +591,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: r""" Args: @@ -598,10 +600,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -734,7 +738,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -782,7 +786,7 @@ def prepare_inputs_for_generation( image_sizes_videos=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -793,7 +797,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 635cda9cc8f0..cc62d378ebae 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -32,6 +32,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_mistral import MistralConfig @@ -363,6 +364,7 @@ class MistralPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -777,6 +779,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -792,7 +795,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -802,10 +805,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -848,7 +853,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 8cf2d0e8fa8d..034ddba8c484 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -55,6 +55,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_mixtral import MixtralConfig @@ -485,6 +486,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -996,6 +998,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1012,7 +1015,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1022,10 +1025,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1074,7 +1079,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index a6069f69b334..a16e4c5a16d9 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -466,7 +466,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -476,10 +476,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -528,7 +530,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index b40c366a6d75..d1f83e13d8d4 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -35,6 +35,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig @@ -1872,6 +1873,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig") def forward( @@ -1890,7 +1892,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1900,10 +1902,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1950,7 +1954,8 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() loss = None if labels is not None: @@ -2014,6 +2019,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.language_model.get_decoder() + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig") def forward( @@ -2034,7 +2040,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -2043,10 +2049,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -2140,7 +2148,7 @@ def forward( output_attentions=output_attentions, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) return outputs @@ -2158,7 +2166,7 @@ def prepare_inputs_for_generation( past_key_values=None, use_cache=False, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -2190,8 +2198,8 @@ def prepare_inputs_for_generation( # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep model_inputs.update( { diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index a1c15b7a0b37..3796e2dc5f35 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -48,6 +48,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto.modeling_auto import AutoModel from .configuration_moshi import MoshiConfig, MoshiDepthConfig @@ -1788,6 +1789,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoshiCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1803,7 +1805,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, MoshiCausalLMOutputWithPast]: r""" Args: @@ -1812,10 +1814,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1861,7 +1865,8 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -2446,7 +2451,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, user_delay_pattern_mask=None, moshi_delay_pattern_mask=None, kwargs_depth_decoder=None, @@ -2463,7 +2468,7 @@ def prepare_inputs_for_generation( cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 54f774f0b942..8ae6e9c77fac 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -46,6 +46,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_nemotron import NemotronConfig @@ -1023,6 +1024,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy (doc string different) @@ -1039,7 +1041,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1049,10 +1051,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1094,7 +1098,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 34e9f7259cdf..c2e1ae15b4b5 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -26,6 +26,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_olmo import OlmoConfig @@ -367,6 +368,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -752,6 +754,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -767,7 +770,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -777,10 +780,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -823,7 +828,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index a6a19265015b..163956d61a22 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -25,6 +25,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_olmo2 import Olmo2Config @@ -368,6 +369,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -753,6 +755,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -768,7 +771,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -778,10 +781,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -824,7 +829,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 5c78138c1a03..47126da95647 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -38,6 +38,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_olmoe import OlmoeConfig @@ -756,7 +757,6 @@ def forward( "The bare Olmoe Model outputting raw hidden-states without any specific head on top.", OLMOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmoe class OlmoePreTrainedModel(PreTrainedModel): config_class = OlmoeConfig base_model_prefix = "model" @@ -765,7 +765,6 @@ class OlmoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -1186,6 +1185,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1202,7 +1202,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1212,10 +1212,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1262,7 +1264,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 36a9e59118b6..5889f92f3c0d 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -32,6 +32,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_paligemma import PaliGemmaConfig @@ -412,6 +413,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor): image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -429,7 +431,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: r""" Args: @@ -438,10 +440,12 @@ def forward( config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -532,7 +536,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs.logits @@ -581,7 +585,7 @@ def prepare_inputs_for_generation( attention_mask=None, token_type_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, labels=None, **kwargs, ): @@ -594,7 +598,7 @@ def prepare_inputs_for_generation( position_ids=position_ids, cache_position=cache_position, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, token_type_ids=token_type_ids, **kwargs, ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8336ab5a2cf5..d1cb49529428 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -46,6 +46,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_persimmon import PersimmonConfig @@ -830,6 +831,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -845,7 +847,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -854,10 +856,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -900,7 +904,8 @@ def forward( hidden_states = outputs[0] # No upscaling to float was ever done for Persimmon - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 33439dff756e..7d360b1ed41e 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -31,6 +31,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_phi import PhiConfig @@ -363,6 +364,7 @@ class PhiPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -750,6 +752,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -765,7 +768,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -775,10 +778,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -821,7 +826,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index cf905cb62e90..e86e028b4027 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -47,6 +47,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_phi3 import Phi3Config @@ -432,6 +433,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True _version = "0.0.5" def _init_weights(self, module): @@ -847,6 +849,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -862,7 +865,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -872,10 +875,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -918,7 +923,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -945,7 +951,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the @@ -970,7 +976,7 @@ def prepare_inputs_for_generation( cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) return model_inputs diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index 2b1a19be4ae2..27f7c42f5bb8 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -275,7 +275,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the @@ -300,7 +300,7 @@ def prepare_inputs_for_generation( cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) return model_inputs diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index b540dd18300e..ba4b76650730 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -40,6 +40,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torch_fx_available from .configuration_phimoe import PhimoeConfig @@ -901,7 +902,6 @@ def forward( "The bare Phimoe Model outputting raw hidden-states without any specific head on top.", PHIMOE_START_DOCSTRING, ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralPreTrainedModel with Mixtral->Phimoe class PhimoePreTrainedModel(PreTrainedModel): config_class = PhimoeConfig base_model_prefix = "model" @@ -910,7 +910,6 @@ class PhimoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -1365,6 +1364,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy @@ -1382,7 +1382,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1392,10 +1392,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python @@ -1445,7 +1447,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1488,7 +1491,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the @@ -1513,7 +1516,7 @@ def prepare_inputs_for_generation( cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) return model_inputs diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f8be4e3740f4..96cd6a6aa32e 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -32,6 +32,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2 import Qwen2Config @@ -376,6 +377,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -761,6 +763,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -776,7 +779,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -786,10 +789,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -832,7 +837,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 0f61323f4030..ad61003c8602 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -49,6 +49,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2_moe import Qwen2MoeConfig @@ -1247,6 +1248,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1263,7 +1265,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1273,10 +1275,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1323,7 +1327,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 4cdab6dc4d2d..55a85a9a1fa2 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -48,6 +48,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_stablelm import StableLmConfig @@ -1086,6 +1087,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy @@ -1102,7 +1104,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1111,10 +1113,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1156,7 +1160,8 @@ def forward( hidden_states = outputs[0] # No upscaling to float was ever done for StableLm - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 500f96e3e30b..57898bc8d616 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -51,6 +51,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_starcoder2 import Starcoder2Config @@ -368,6 +369,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -773,6 +775,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -788,7 +791,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -798,10 +801,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -844,7 +849,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 293fb10ae277..f592da818549 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -31,6 +31,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_video_llava import VideoLlavaConfig @@ -409,6 +410,7 @@ def get_video_features(self, pixel_values_videos: torch.FloatTensor, vision_feat return video_features, num_frames + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -428,7 +430,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]: r""" Args: @@ -437,10 +439,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -579,7 +583,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -625,7 +629,7 @@ def prepare_inputs_for_generation( pixel_values_videos=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -636,7 +640,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 0daaa8327b63..8ef881b771cb 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -31,6 +31,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_vipllava import VipLlavaConfig @@ -373,6 +374,7 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in return final_embedding, final_attention_mask, final_labels, position_ids + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy @@ -391,7 +393,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]: r""" Args: @@ -400,10 +402,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -479,7 +483,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, ) logits = outputs[0] @@ -521,7 +525,7 @@ def prepare_inputs_for_generation( pixel_values=None, attention_mask=None, cache_position=None, - num_logits_to_keep=None, + logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -532,7 +536,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 761c799bdcdc..a25cfbc42862 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -48,6 +48,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_ssm_available, @@ -1217,6 +1218,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ZAMBA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1232,7 +1234,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1242,10 +1244,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1289,7 +1293,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1355,7 +1360,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py index e8416c9f116e..064decb14dba 100644 --- a/src/transformers/utils/deprecation.py +++ b/src/transformers/utils/deprecation.py @@ -19,7 +19,12 @@ import packaging.version from .. import __version__ -from . import ExplicitEnum +from . import ExplicitEnum, is_torch_available, is_torchdynamo_compiling + + +# This is needed in case we deprecate a kwarg of a function/method being compiled +if is_torch_available(): + import torch # noqa: F401 class Action(ExplicitEnum): @@ -40,6 +45,7 @@ def deprecate_kwarg( ): """ Function or method decorator to notify users about deprecated keyword arguments, replacing them with a new name if specified. + Note that is decorator is `torch.compile`-safe, i.e. it will not cause graph breaks (but no warning will be displayed if compiling). This decorator allows you to: - Notify users when a keyword argument is deprecated. @@ -158,7 +164,8 @@ def wrapped_func(*args, **kwargs): # raise error or notify user if minimum_action == Action.RAISE: raise ValueError(message) - elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS): + # If we are compiling, we do not raise the warning as it would break compilation + elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling(): # DeprecationWarning is ignored by default, so we use FutureWarning instead warnings.warn(message, FutureWarning, stacklevel=2) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ba61d4b43677..b47566354b44 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2029,10 +2029,10 @@ def test_generate_compile_model_forward(self): self._check_similar_generate_outputs(dynamic_result, compiled_result) @pytest.mark.generate - def test_generate_methods_with_num_logits_to_keep(self): + def test_generate_methods_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `logits_to_keep` argument.") config, inputs_dict = self.prepare_config_and_inputs_for_generate() config.use_cache = True @@ -2047,17 +2047,17 @@ def test_generate_methods_with_num_logits_to_keep(self): "do_sample": False, } - # Setting num_logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0) - # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + # Setting logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0) + # By default, logits_to_keep is automatically set to 1 if not provided (new behavior) without_all_logits = model.generate(**inputs_dict, **generation_kwargs) self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) @pytest.mark.generate - def test_assisted_decoding_with_num_logits_to_keep(self): + def test_assisted_decoding_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `logits_to_keep` argument.") if model_class._is_stateful: self.skipTest(reason="Stateful models don't support assisted generation") @@ -2081,9 +2081,9 @@ def test_assisted_decoding_with_num_logits_to_keep(self): "output_scores": True, } - # Setting num_logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0) - # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + # Setting logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0) + # By default, logits_to_keep is automatically set to 1 if not provided (new behavior) without_all_logits = model.generate(**inputs_dict, **generation_kwargs) self._check_similar_generate_outputs(with_all_logits, without_all_logits) diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 9356824dabda..16be88f94949 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -531,7 +531,7 @@ def test_simple_generate(self): # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist if self.cuda_compute_capability_major_version == 8: with torch.no_grad(): - logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits + logits = self.model(input_ids=input_ids, logits_to_keep=40).logits EXPECTED_LOGITS_NO_GRAD = torch.tensor( [ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cf259fabe302..148dc1a86575 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4780,21 +4780,21 @@ def test_torch_compile_for_training(self): for name, param in model._orig_mod.named_parameters(): torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4) - def test_forward_with_num_logits_to_keep(self): + def test_forward_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `logits_to_keep` argument.") config, inputs = self.model_tester.prepare_config_and_inputs_for_common() batch_size, sequence_length = inputs["input_ids"].shape vocab_size = config.get_text_config().vocab_size model = model_class(config).to(device=torch_device).eval() - # some models have labels but `num_logits_to_keep` should not be used in train mode + # some models have labels but `logits_to_keep` should not be used in train mode _ = inputs.pop("labels", None) - # num_logits_to_keep=0 is a special case meaning "keep all logits" - all_logits = model(**inputs, num_logits_to_keep=0).logits - last_token_logits = model(**inputs, num_logits_to_keep=1).logits + # logits_to_keep=0 is a special case meaning "keep all logits" + all_logits = model(**inputs, logits_to_keep=0).logits + last_token_logits = model(**inputs, logits_to_keep=1).logits # Assert all shapes are correct self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size)) diff --git a/tests/utils/test_deprecation.py b/tests/utils/test_deprecation.py index e8e7e671ad2e..bf9f63e070b9 100644 --- a/tests/utils/test_deprecation.py +++ b/tests/utils/test_deprecation.py @@ -17,10 +17,15 @@ from parameterized import parameterized -from transformers import __version__ +from transformers import __version__, is_torch_available +from transformers.testing_utils import require_torch_gpu from transformers.utils.deprecation import deprecate_kwarg +if is_torch_available(): + import torch + + INFINITE_VERSION = "9999.0.0" @@ -168,3 +173,23 @@ def dummy_function(new_name=None, **kwargs): with self.assertWarns(FutureWarning): result = dummy_function(deprecated_name="old_value", new_name="new_value") self.assertEqual(result, "new_value") + + @require_torch_gpu + def test_compile_safe(self): + @deprecate_kwarg("deprecated_factor", new_name="new_factor", version=INFINITE_VERSION) + def dummy_function(new_factor=None, **kwargs): + return new_factor * torch.ones(1, device="cuda") + + compiled_function = torch.compile(dummy_function, fullgraph=True) + + # Check that we can correctly call the compiled function with the old name, without raising errors + out = compiled_function(deprecated_factor=2) + self.assertEqual(out.item(), 2) + + # Check that we can correctly call the compiled function with the new name, without raising errors + out = compiled_function(new_factor=2) + self.assertEqual(out.item(), 2) + + # Check that we can correctly call the compiled function with both names, without raising errors + out = compiled_function(new_factor=2, deprecated_factor=10) + self.assertEqual(out.item(), 2)