Description
System Info
- transformers version: 4.40.1
- Platform: Linux-4.18.0-513.24.1.el8_9.x86_64-x86_64-with-glibc2.28
- Python version: 3.10.13
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.1
- Accelerate version: 0.29.2
- Accelerate config: not found
- PyTorch version (GPU?): 2.3.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: FSDP
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Wrapping a LlamaModel with FSDP results in the following error during a forward pass:
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
outputs = self.model(
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1016, in forward
layer_outputs = decoder_layer(
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
return self.checkpoint_fn( # type: ignore[misc]
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
ret = function(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 739, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'offload_to_cpu'
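For reference, a minimal sketch of a setup that can produce a traceback like the one above. The issue does not include the training script, so everything here is an assumption: the model id, the plain FSDP wrapping, and in particular the `offload_to_cpu=True` kwarg passed to `checkpoint_wrapper`, which is inferred from the error message.

```python
import functools

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Assumes a torchrun launch; the model id is a placeholder.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = FSDP(model, device_id=torch.cuda.current_device())

# Activation checkpointing on each decoder layer. Passing `offload_to_cpu=True`
# here is an assumption: it is one plausible way the kwarg ends up forwarded
# through torch.utils.checkpoint into the decoder layer's forward(**kwargs),
# matching the checkpoint_wrapper frames in the traceback above.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(checkpoint_wrapper, offload_to_cpu=True),
    check_fn=lambda m: isinstance(m, LlamaDecoderLayer),
)

input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
model(input_ids=input_ids)  # raises TypeError: ... 'offload_to_cpu'
```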
This occurs because the decoder layer forwards **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L749) to an attention forward() whose signature does not accept **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L608).
With another model, e.g. Mistral, this issue does not occur, because the decoder layer does not forward **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L757C63-L757C77).
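A stripped-down illustration of the mismatch (plain Python stand-ins, not the actual transformers functions):

```python
# Simplified stand-ins for the two call sites described above; the names are
# illustrative only.
def self_attn_forward(hidden_states, attention_mask=None, output_attentions=False):
    # No **kwargs in the signature, mirroring LlamaSdpaAttention.forward().
    return hidden_states

def decoder_layer_forward(hidden_states, **kwargs):
    # The Llama decoder layer forwards **kwargs to the attention module, so any
    # extra key injected upstream (here `offload_to_cpu`) reaches a signature
    # that cannot accept it.
    return self_attn_forward(hidden_states, **kwargs)

decoder_layer_forward("hidden", offload_to_cpu=True)
# TypeError: self_attn_forward() got an unexpected keyword argument 'offload_to_cpu'
```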
Expected behavior
Either remove the **kwargs forwarding on line 749 or add **kwargs to LlamaSdpaAttention.forward().
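A sketch of the second option. This is not the actual transformers code: the parameter list below is abbreviated from the 4.40.1 signature and wrapped in a dummy module so it runs on its own.

```python
# Hypothetical sketch of the proposed signature change: adding **kwargs to the
# attention forward() so extra keys forwarded by the decoder layer are ignored
# instead of raising a TypeError.
import torch
from torch import nn

class SdpaAttentionSketch(nn.Module):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        **kwargs,  # accept and drop unexpected keys such as `offload_to_cpu`
    ):
        return hidden_states, None, past_key_value

attn = SdpaAttentionSketch()
out, _, _ = attn(torch.zeros(1, 4, 8), offload_to_cpu=True)  # no TypeError
```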