
GenerationMixin: model_kwargs not passed to model in assisted decoding #25020

@sinking-point

System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
  • Python version: 3.11.4
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")
assist = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("The first rule of fight", return_tensors='pt')
# All-zero token types, one per prompt token (shape [1, 5] for this prompt)
token_type_ids = torch.zeros_like(inputs["input_ids"])

# Regular greedy decoding: token_type_ids reaches the model
outputs = model.generate(**inputs, token_type_ids=token_type_ids)
print(tokenizer.decode(outputs[0]))

# output: The first rule of fight!!!!!!!!!!!!!!!

# Assisted decoding with the same arguments: token_type_ids does not reach the model
outputs = model.generate(**inputs, token_type_ids=token_type_ids, num_beams=1, assistant_model=assist)
print(tokenizer.decode(outputs[0]))

# output: The first rule of fight-or-flight is to be prepared for the enemy. If you are

Expected behavior

I would expect assisted generation to produce the same output as regular generation, since token_type_ids is passed to generate in both cases. The generate method is expected to forward extra keyword arguments (model_kwargs) to the model through its prepare_inputs_for_generation method.

However, assisted generation does not forward model_kwargs to the model the way the other generation methods do, so token_type_ids is silently ignored.
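To make the expected contract concrete, here is a minimal, self-contained sketch in plain Python. It does not use transformers: `toy_main`, `toy_assistant`, `greedy`, `assisted`, the `forward_model_kwargs` flag and the toy `prepare_inputs_for_generation` below are all hypothetical stand-ins, not the library's API. Under greedy decoding, an assistant should only change how quickly tokens are produced, never which tokens, so assisted output should match the regular loop exactly as long as both route the same model_kwargs (here token_type_ids) through the input-preparation hook. Setting `forward_model_kwargs=False` mimics the behavior reported above.

```python
from typing import Callable, List, Optional

Model = Callable[..., int]  # hypothetical: maps a prefix (+ kwargs) to the next token id


def toy_main(input_ids: List[int], token_type_ids: Optional[List[int]] = None) -> int:
    # Hypothetical main model: the prediction depends on the last token type,
    # so dropping token_type_ids changes its output.
    shift = token_type_ids[-1] if token_type_ids is not None else 0
    return (input_ids[-1] + 1 + shift) % 10


def toy_assistant(input_ids: List[int], token_type_ids: Optional[List[int]] = None) -> int:
    # Hypothetical draft model: ignores token types, so some drafts get rejected.
    return (input_ids[-1] + 1) % 10


def prepare_inputs_for_generation(
    input_ids: List[int], token_type_ids: Optional[List[int]] = None, **_unused
) -> dict:
    # Toy stand-in for the model hook: builds the forward() kwargs and repeats
    # the last token type for every generated position.
    inputs = {"input_ids": input_ids}
    if token_type_ids is not None:
        pad = [token_type_ids[-1]] * (len(input_ids) - len(token_type_ids))
        inputs["token_type_ids"] = (token_type_ids + pad)[: len(input_ids)]
    return inputs


def greedy(model: Model, input_ids: List[int], max_new_tokens: int, **model_kwargs) -> List[int]:
    # Regular loop: model_kwargs go through the hook at every step.
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        ids.append(model(**prepare_inputs_for_generation(ids, **model_kwargs)))
    return ids


def assisted(model: Model, assistant: Model, input_ids: List[int], max_new_tokens: int,
             num_draft: int = 3, forward_model_kwargs: bool = True, **model_kwargs) -> List[int]:
    kwargs = model_kwargs if forward_model_kwargs else {}  # False reproduces the bug
    ids = list(input_ids)
    target = len(ids) + max_new_tokens
    while len(ids) < target:
        # 1. The assistant drafts candidate tokens greedily from the current prefix.
        draft = list(ids)
        for _ in range(num_draft):
            draft.append(assistant(draft))
        # 2. The main model checks each draft position: matching tokens are kept,
        #    and at the first mismatch (or after the last draft) the main model's
        #    own token is appended instead.
        for cand in draft[len(ids):] + [None]:
            nxt = model(**prepare_inputs_for_generation(ids, **kwargs))
            ids.append(nxt)
            if nxt != cand or len(ids) >= target:
                break
    return ids


prompt, types = [0, 1, 2], [1, 1, 1]
print(greedy(toy_main, prompt, 5, token_type_ids=types))  # [0, 1, 2, 4, 6, 8, 0, 2]
print(assisted(toy_main, toy_assistant, prompt, 5, token_type_ids=types))  # [0, 1, 2, 4, 6, 8, 0, 2]
print(assisted(toy_main, toy_assistant, prompt, 5,
               forward_model_kwargs=False, token_type_ids=types))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

The toy verifies draft tokens one at a time for readability; a real implementation scores all candidates in a single forward pass, which accepts the same tokens under greedy decoding and is where the speedup comes from.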
