diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index b5569dc98906..13523539205b 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -1163,6 +1163,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
@@ -1208,25 +1209,26 @@ def forward(
             torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)

-        lm_logits = self.lm_head(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])

         loss = None
         if labels is not None:
             # Flatten the tokens
             loss = self.loss_function(
-                lm_logits,
+                logits,
                 labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )

         if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
+            output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
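
A minimal sketch of how a caller might exercise the new argument once a `transformers` build containing this patch is installed; the checkpoint name and prompt below are illustrative, not part of the change. With the default `logits_to_keep=0`, `slice(-0, None)` selects every position, so existing callers see no behavior change; passing `logits_to_keep=1` projects only the final hidden state through `lm_head`, which is all that single-step decoding needs.

```python
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

# Illustrative checkpoint; any GPT-2 causal LM checkpoint would do.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

with torch.no_grad():
    # Default: logits for every position, shape (batch, seq_len, vocab_size).
    full = model(**inputs)
    # With the patch: project only the last position through lm_head,
    # shape (batch, 1, vocab_size).
    last_only = model(**inputs, logits_to_keep=1)

print(full.logits.shape, last_only.logits.shape)
```

Skipping the vocabulary projection for all but the kept positions mainly saves memory and compute when the sequence is long and only the final-token distribution is needed, matching how the same argument is used in other causal LM heads in the library.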