Running assisted generation (AG) and speculative decoding (SD) when the assistant and target models are on different devices #35099

@jmamou

Description

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-5.15.0-119-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.26.3
  • Safetensors version: 0.4.1
  • Accelerate version: 0.26.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?: yes
  • GPU type: NVIDIA RTX A6000

Who can help?

@zucchini-nlp @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Modify the example from https://huggingface.co/docs/transformers/main/en/generation_strategies#speculative-decoding so that the target and draft models are placed on different devices:

from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Keep the inputs on the target model's device; without this the call fails
# earlier with a cpu/cuda mismatch rather than the cuda:0/cuda:1 one below.
inputs = tokenizer(prompt, return_tensors="pt").to('cuda:0')

# Target model on one GPU, draft (assistant) model on another.
model = AutoModelForCausalLM.from_pretrained(checkpoint).to('cuda:0')
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).to('cuda:1')
outputs = model.generate(**inputs, assistant_model=assistant_model)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
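
For reference, device placement can be confirmed just before the generate call; this is an optional, illustrative check of the setup above, not part of the original reproduction.

# Optional sanity check of where each component lives.
print(model.device)             # expected: cuda:0
print(assistant_model.device)   # expected: cuda:1
print(inputs.input_ids.device)  # expected: cuda:0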

We get the following error:

Traceback (most recent call last):
  File "/home/jmamou/dynamicSL/scripts/run_sd.py", line 16, in <module>
    outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, max_new_tokens=512)
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jmamou/dynamicSL/transformers/src/transformers/generation/utils.py", line 2197, in generate
    result = self._assisted_decoding(
  File "/home/jmamou/dynamicSL/transformers/src/transformers/generation/utils.py", line 4300, in _assisted_decoding
    outputs = self(**model_inputs)
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jmamou/dynamicSL/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 1173, in forward
    outputs = self.gpt_neox(
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jmamou/dynamicSL/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 878, in forward
    inputs_embeds = self.embed_in(input_ids)
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/jmamou/miniconda3/envs/bench_generation/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)
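
The underlying failure is the generic cross-device indexing check in torch.embedding: the target model's verification forward is (presumably) handed input ids that still live on the assistant's device. The same error can be reproduced in isolation; a minimal sketch, assuming a machine with at least two CUDA devices:

import torch

emb = torch.nn.Embedding(10, 4).to('cuda:0')      # weights on one device
ids = torch.tensor([[1, 2, 3]], device='cuda:1')  # indices on another
emb(ids)  # RuntimeError: Expected all tensors to be on the same device ...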

The error does not occur when both models are on the same device.
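
For comparison, a sketch of the working configuration, with only the device placement changed relative to the reproduction above:

# Both models share cuda:0; assisted generation then completes normally.
model = AutoModelForCausalLM.from_pretrained(checkpoint).to('cuda:0')
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).to('cuda:0')
outputs = model.generate(**inputs, assistant_model=assistant_model)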

Expected behavior

['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
