Closed
Description
System Info
- transformers version: 4.31.0
- Platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
- Python version: 3.11.4
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.1
- Accelerate version: 0.21.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.0.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Main model plus a smaller assistant model that shares the GPT-2 tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
assist = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# The prompt tokenizes to 5 tokens, so token_type_ids has 5 entries
inputs = tokenizer("The first rule of fight", return_tensors='pt')

# Regular greedy generation with explicit token_type_ids
outputs = model.generate(**inputs, token_type_ids=torch.tensor([[0,0,0,0,0]], dtype=torch.long))
print(tokenizer.decode(outputs[0]))
# output: The first rule of fight!!!!!!!!!!!!!!!

# Assisted generation with the same token_type_ids
outputs = model.generate(**inputs, token_type_ids=torch.tensor([[0,0,0,0,0]], dtype=torch.long), num_beams=1, assistant_model=assist)
print(tokenizer.decode(outputs[0]))
# output: The first rule of fight-or-flight is to be prepared for the enemy. If you are
```
Expected behavior
I would expect assisted generation to produce the same output as regular generation, since token_type_ids is passed to generate in both cases. The generate method is expected to pass extra arguments through to the model via the model's prepare_inputs_for_generation method.
However, assisted generation does not forward model_kwargs to the model the way the other generation methods do.
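The difference between the two code paths can be illustrated with a small pure-Python sketch. All names in it (toy_next_token, prepare_inputs, update_model_kwargs, greedy_generate) are hypothetical stand-ins, not the transformers API: prepare_inputs loosely plays the role of prepare_inputs_for_generation, and update_model_kwargs extends token_type_ids by one entry per generated token so they stay aligned with the sequence. The toy model reacts to the mere presence of token_type_ids, so a loop that drops them produces different output, mirroring the behavior reported above.

```python
from typing import Any, Dict, List, Optional


def toy_next_token(input_ids: List[int], token_type_ids: Optional[List[int]] = None) -> int:
    """Hypothetical stand-in for one forward pass plus argmax (not a real model)."""
    if token_type_ids is not None:
        # Extra kwargs must stay aligned with the growing sequence.
        assert len(token_type_ids) == len(input_ids), "misaligned token_type_ids"
        # Any supplied token_type_ids (even all zeros) change the prediction.
        return 0
    return (input_ids[-1] + 1) % 100


def prepare_inputs(input_ids: List[int], **model_kwargs: Any) -> Dict[str, Any]:
    """Hypothetical analogue of prepare_inputs_for_generation."""
    inputs: Dict[str, Any] = {"input_ids": input_ids}
    # Forward the extra kwarg that the toy model understands.
    if "token_type_ids" in model_kwargs:
        inputs["token_type_ids"] = model_kwargs["token_type_ids"]
    return inputs


def update_model_kwargs(model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Extend token_type_ids by repeating the last type for the new token."""
    updated = dict(model_kwargs)
    if "token_type_ids" in updated:
        types = updated["token_type_ids"]
        updated["token_type_ids"] = types + [types[-1]]
    return updated


def greedy_generate(
    input_ids: List[int],
    max_new_tokens: int,
    forward_kwargs: bool = True,
    **model_kwargs: Any,
) -> List[int]:
    """Toy greedy loop; forward_kwargs=False mimics a path that ignores model_kwargs."""
    ids = list(input_ids)
    kwargs = dict(model_kwargs)
    for _ in range(max_new_tokens):
        if forward_kwargs:
            step_inputs = prepare_inputs(ids, **kwargs)
        else:
            step_inputs = {"input_ids": ids}  # extra kwargs silently dropped
        ids.append(toy_next_token(**step_inputs))
        kwargs = update_model_kwargs(kwargs)
    return ids


if __name__ == "__main__":
    prompt = [10, 11, 12, 13, 14]
    print(greedy_generate(prompt, 3, token_type_ids=[0] * 5))
    print(greedy_generate(prompt, 3, forward_kwargs=False, token_type_ids=[0] * 5))
```

With the prompt [10, 11, 12, 13, 14] and five zero token types, the forwarding loop returns [10, 11, 12, 13, 14, 0, 0, 0], while the loop that drops the kwargs returns [10, 11, 12, 13, 14, 15, 16, 17]: same inputs, different outputs, purely because one path ignores token_type_ids.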