Description
System Info
- transformers version: 4.45.0.dev0
- Platform: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
- Python version: 3.10.14
- Huggingface_hub version: 0.24.5
- Safetensors version: 0.4.4
- Accelerate version: 0.33.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA GeForce RTX 4090 D
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Hi, I'm fine-tuning the newly released Qwen2VLForConditionalGeneration model with LoRA. I build the model with:
```python
import torch
from transformers import Qwen2VLForConditionalGeneration
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2", torch_dtype=torch.float16
)
```
I found that attn_implementation="flash_attention_2" activates Qwen2VLFlashAttention2, which throws an index-out-of-range error during training. When I switch to attn_implementation="sdpa", the error does not occur and training runs smoothly.
After some debugging, I traced the problem to this line, where rotary_seq_len does not reflect the actual length of the input sequence but instead the real length minus 1. I changed this line to rotary_seq_len = cache_position[-1] + 1 in my local (offline) copy of the transformers package, and training with flash_attention_2 now runs smoothly.
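To illustrate the off-by-one, here is a minimal pure-Python sketch. It assumes, as the description above suggests, that the rotary cos/sin table is built with rotary_seq_len rows and then indexed by position id, and that cache_position runs 0 .. seq_len - 1 for a training step with no KV cache. build_rotary_table and gather are hypothetical stand-ins, not the actual transformers code. They only show why a table with cache_position[-1] rows is one row short.

```python
import math


def build_rotary_table(seq_len, dim=4, base=10000.0):
    """Hypothetical stand-in for a rotary-embedding cache:
    one row of (cos, sin) pairs per position 0 .. seq_len - 1."""
    inv_freq = [1.0 / (base ** (2 * i / dim)) for i in range(dim // 2)]
    return [
        [(math.cos(pos * f), math.sin(pos * f)) for f in inv_freq]
        for pos in range(seq_len)
    ]


def gather(table, position_ids):
    # Mirrors cos[position_ids]: raises IndexError if any id >= len(table).
    return [table[p] for p in position_ids]


seq_len = 8
cache_position = list(range(seq_len))  # assumed 0 .. 7, like torch.arange(seq_len)
position_ids = list(range(seq_len))    # plain text tokens: ids 0 .. 7

buggy_len = cache_position[-1]         # 7: one row short of seq_len
fixed_len = cache_position[-1] + 1     # 8: covers every position id

try:
    gather(build_rotary_table(buggy_len), position_ids)
except IndexError as exc:
    print("buggy:", exc)               # the last id (7) has no row

rows = gather(build_rotary_table(fixed_len), position_ids)
print("fixed:", len(rows))             # one row per token
```

Under these assumptions, cache_position[-1] is the last index rather than a length, so adding 1 turns it into the number of positions the table must cover.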
My input batch to the model is as follows:
batch
input_ids: Tensor (B, seq_len)
attention_mask: Tensor (B, seq_len)
labels: Tensor (B, seq_len)
pixel_values: Tensor (B, res_h, res_w) # res_h and res_w are the shape of image after processor()
image_grid_thw: Tensor (B, 3)
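One way to double-check that the batch matches the layout listed above is a small consistency check on the shapes. This is a minimal sketch: check_batch_shapes is a hypothetical helper that takes plain shape tuples (for real tensors, something like {k: tuple(v.shape) for k, v in batch.items()} would produce them). It only checks the layout described in this issue, not what the model actually expects.

```python
def check_batch_shapes(shapes):
    """Hypothetical helper: `shapes` maps each batch key to a shape tuple.
    Checks the layout described in this issue and returns (B, seq_len)."""
    b, seq_len = shapes["input_ids"]
    # attention_mask and labels must line up token-for-token with input_ids.
    for key in ("attention_mask", "labels"):
        if shapes[key] != (b, seq_len):
            raise ValueError(f"{key} has shape {shapes[key]}, expected {(b, seq_len)}")
    # One (t, h, w) grid per sample.
    if shapes["image_grid_thw"] != (b, 3):
        raise ValueError(
            f"image_grid_thw has shape {shapes['image_grid_thw']}, expected {(b, 3)}"
        )
    # pixel_values is described as (B, res_h, res_w); only the batch dim is checked.
    if shapes["pixel_values"][0] != b:
        raise ValueError("pixel_values batch dimension does not match input_ids")
    return b, seq_len


example = {
    "input_ids": (2, 128),
    "attention_mask": (2, 128),
    "labels": (2, 128),
    "pixel_values": (2, 448, 448),
    "image_grid_thw": (2, 3),
}
print(check_batch_shapes(example))
```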
I believe my input batch to the model has the correct shapes, so I'm wondering whether my small workaround is the right fix for the problem. I'd really appreciate it if you could suggest a better solution.
Expected behavior
See the Reproduction section. Thanks for your patience with this issue.