Merged

30 commits
d4a79ce  adding option for 2.5 (Aug 31, 2025)
6b8f487  minor - arg in conversion script (Aug 31, 2025)
d0697ce  getting started on modelling.py (Sep 1, 2025)
26b8cbf  minor - shouldve been using modular (Sep 1, 2025)
f5299d2  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Sep 1, 2025)
14d96d3  adressing comments + fixing datatype/device _get method (Sep 1, 2025)
3ac06f9  minor (Sep 1, 2025)
8bc5d73  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Sep 4, 2025)
0656367  Merge branch 'huggingface:main' into integrate_colqwen2.5_using_colqw… (sahil-kabir, Sep 5, 2025)
73b029b  commiting suggestion (sahil-kabir, Sep 5, 2025)
3aa8aa8  docs + first test (Sep 5, 2025)
d4be146  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Sep 5, 2025)
f591764  ruff fix (Sep 10, 2025)
9577aae  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Sep 10, 2025)
e9ea6b6  minor fix (Sep 10, 2025)
6ae49f6  ruff fix (Sep 10, 2025)
9297f9e  model fix (Sep 10, 2025)
6a62d82  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Sep 10, 2025)
2032bd5  Merge branch 'huggingface:main' into integrate_colqwen2.5_using_colqw… (sahil-kabir, Sep 12, 2025)
a0a6245  model fix (Sep 13, 2025)
db2df86  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Sep 13, 2025)
272a7dc  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Sep 27, 2025)
5ca07ce  fine-grained check, with a hardcoded score from the original Hf imple… (Sep 27, 2025)
961fb9f  minor ruff (Sep 27, 2025)
b6b454e  Merge remote-tracking branch 'upstream/main' into integrate_colqwen2.… (yonigozlan, Oct 3, 2025)
76238d3  update tests values with CI hardware (yonigozlan, Oct 3, 2025)
0582b59  adding 2.5 to conversion script (Oct 19, 2025)
30dc9d9  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (sahil-kabir, Oct 19, 2025)
26fe35c  Apply style fixes (github-actions[bot], Nov 3, 2025)
673289b  Merge branch 'main' into integrate_colqwen2.5_using_colqwen2_modellin… (yonigozlan, Nov 3, 2025)
18 changes: 18 additions & 0 deletions docs/source/en/model_doc/colqwen2.md
@@ -158,6 +158,24 @@ print("Retrieval scores (query x image):")
print(scores)
```

You can also load `ColQwen2.5` checkpoints that are **compatible with the ColQwen2 architecture**. This version of the model uses [Qwen2_5_VL](./qwen2_5_vl) as its backbone.

```python
import torch
from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from transformers.utils.import_utils import is_flash_attn_2_available

model_name = "Sahil-Kabir/colqwen2.5-v0.2-hf" # An existing compatible checkpoint

model = ColQwen2ForRetrieval.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)
processor = ColQwen2Processor.from_pretrained(model_name)
```
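
The model and processor can then be used exactly as in the ColQwen2 example earlier in this document. A minimal usage sketch (the sample image URL and query below are illustrative placeholders, not part of this PR):

```python
import requests
from PIL import Image

# Illustrative inputs: any document image and matching query will do.
url = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
image = Image.open(requests.get(url, stream=True).raw)
query = "When was the Declaration of Independence proclaimed?"

# Preprocess the inputs and move them to the model's device
batch_images = processor(images=[image]).to(model.device)
batch_queries = processor(text=[query]).to(model.device)

# Forward passes to get the multi-vector embeddings
with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings
    query_embeddings = model(**batch_queries).embeddings

# 2D score tensor: one row per query, one column per image
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print(scores)
```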

## Notes

- [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.
22 changes: 19 additions & 3 deletions src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py
@@ -39,9 +39,10 @@

import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from safetensors import safe_open

from transformers import AutoConfig
from transformers import AutoConfig, AutoModel
from transformers.models.colqwen2 import ColQwen2ForRetrieval
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.utils import logging
@@ -69,7 +70,7 @@ def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> d
                original_state_dict[key] = f.get_tensor(key)

    # Some weights are tied, so `lm_head` is not saved. Let's clone it to load the state dict.
    if "lm_head.weight" not in original_state_dict:
    if "lm_head.weight" not in original_state_dict and "model.embed_tokens.weight" in original_state_dict:
        original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()

    return original_state_dict
@@ -124,7 +125,21 @@ def convert_colqwen2_weights_to_hf(
    config.is_composition = False

    # Load the untrained model
    model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
    vlm_name_or_path = getattr(config.vlm_config, "_name_or_path", None)
    if vlm_name_or_path and "2.5" in str(vlm_name_or_path):
        print(
            "Detected colqwen2.5 adapters in vlm_config; loading base model %s and merging PEFT weights."
            % vlm_name_or_path
        )
        base_model = AutoModel.from_pretrained(
            vlm_name_or_path,
            device_map="cpu",
            trust_remote_code=True,
        )
        peft_model = PeftModel.from_pretrained(base_model, model_id)
        model = peft_model.merge_and_unload()
    else:
        model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
    print("Created model with new config and randomly initialized weights")

    # NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
@@ -201,6 +216,7 @@ def convert_colqwen2_weights_to_hf(
        help="Name or path of the original VLM backbone model",
        default=None,
    )

    args = parser.parse_args()

    convert_colqwen2_weights_to_hf(
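
For reference, here is a standalone sketch of the adapter-merge step that the new branch above performs, under the assumption that the checkpoint ships LoRA adapters (the adapter repo id below is hypothetical):

```python
from peft import PeftModel
from transformers import AutoModel

# Load the backbone the adapters were trained on (kept on CPU for conversion).
base_model = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", device_map="cpu")

# Attach the LoRA adapters on top of the frozen backbone.
# NOTE: "some-user/colqwen2.5-adapters" is a hypothetical adapter repo id.
peft_model = PeftModel.from_pretrained(base_model, "some-user/colqwen2.5-adapters")

# Fold the low-rank deltas into the base weights (W <- W + scaling * B @ A)
# and drop the PEFT wrappers, leaving a plain transformers model.
model = peft_model.merge_and_unload()
```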
1 change: 0 additions & 1 deletion src/transformers/models/colqwen2/modeling_colqwen2.py
@@ -172,7 +172,6 @@ def forward(
        inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)

        if pixel_values is not None:
            pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
            image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
            image_mask = (
                (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
1 change: 0 additions & 1 deletion src/transformers/models/colqwen2/modular_colqwen2.py
@@ -359,7 +359,6 @@ def forward(
        inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)

        if pixel_values is not None:
            pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
            image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
            image_mask = (
                (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
55 changes: 52 additions & 3 deletions tests/models/colqwen2/test_modeling_colqwen2.py
@@ -335,12 +335,61 @@ def test_model_integration_test(self):
                    [15.6562, 12.2656, 20.2969],
                ],
                ("cuda", 8): [
                    [15.0703, 8.7422, 15.0312],
                    [9.5078, 16.8906, 10.6250],
                    [15.6484, 12.3984, 20.4688],
                    [16.2812, 8.3672, 14.5703],
                    [9.4922, 17.1875, 10.3281],
                    [15.0312, 11.3984, 20.1719],
                ],
            }
        )
        expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype)

        assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"

    @slow
    def test_model_integration_test_2(self):
        """
        Test if the model is able to retrieve the correct pages for a small and easy dataset.
        This test uses a ColQwen2.5 checkpoint that is compatible with the ColQwen2 architecture.
        """
        model = ColQwen2ForRetrieval.from_pretrained(
            "Sahil-Kabir/colqwen2.5-v0.2-hf",
            device_map=torch_device,
            dtype=torch.bfloat16,
        ).eval()
        processor = ColQwen2Processor.from_pretrained("Sahil-Kabir/colqwen2.5-v0.2-hf", trust_remote_code=True)

        # Load the test dataset
        ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")

        # Preprocess the examples
        batch_images = processor(images=list(ds["image"])).to(torch_device)
        batch_queries = processor(text=list(ds["query"])).to(torch_device)

        with torch.inference_mode():
            image_embeddings = model(**batch_images).embeddings
            query_embeddings = model(**batch_queries).embeddings

        # Compute retrieval scores
        scores = processor.score_retrieval(
            query_embeddings=query_embeddings,
            passage_embeddings=image_embeddings,
        )

        assert scores.ndim == 2, f"Expected a 2D tensor, got {scores.ndim} dimensions"
        assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"

        # Check that the maximum score in each row lies on the diagonal of the score matrix
        self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())

        # Further validation: fine-grained check against hardcoded scores from the original HF implementation
        expectations = Expectations(
            {
                ("cuda", 8): [
                    [16.3750, 10.9375, 14.7500],
                    [11.3750, 16.8750, 12.0625],
                    [15.3125, 13.1250, 21.5000],
                ]
            }
        )
        expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype)

        assert torch.allclose(scores, expected_scores, atol=0.15), f"Expected scores {expected_scores}, got {scores}"