diff --git a/tests/conftest.py b/tests/conftest.py
index c7a349f1e9e2..7063be069ed6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -385,8 +385,8 @@ def generate_greedy_logprobs_limit(
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
 
-    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
-        return self.model.encode(prompts)
+    def encode(self, prompts: List[str], **kwargs) -> List[List[torch.Tensor]]:
+        return self.model.encode(prompts, **kwargs)
 
     def __enter__(self):
         return self
diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py
index 6556998b68a7..35f1fc350e33 100644
--- a/tests/models/test_embedding.py
+++ b/tests/models/test_embedding.py
@@ -1,6 +1,6 @@
-"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
+"""Compare the outputs of HF and vLLM for embedding models.
 
-Run `pytest tests/models/test_llama_embedding.py`.
+Run `pytest tests/models/test_embedding.py`.
 """
 import pytest
 import torch
@@ -8,6 +8,8 @@
 
 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
+    "ssmits/Qwen2-7B-Instruct-embed-base",
+    "Alibaba-NLP/gte-Qwen2-7B-instruct",
 ]
 
 
@@ -29,7 +31,7 @@ def test_models(
     dtype: str,
 ) -> None:
     with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
-        hf_outputs = hf_model.encode(example_prompts)
+        hf_outputs = hf_model.encode(example_prompts, normalize_embeddings=True)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
diff --git a/vllm/config.py b/vllm/config.py
index 35945e34452d..e2ed68585ee4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -195,6 +195,10 @@ def _verify_tokenizer_mode(self) -> None:
 
     def _verify_embedding_mode(self) -> None:
         architectures = getattr(self.hf_config, "architectures", [])
+        # FIXME: Special handling for gte-Qwen2
+        if "gte-Qwen2" in self.model:
+            architectures = ["Qwen2EmbeddingModel"]
+
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
@@ -277,6 +281,10 @@ def verify_with_parallel_config(
 
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         architectures = getattr(self.hf_config, "architectures", [])
+        # FIXME: Special handling for gte-Qwen series
+        if "gte-Qwen2" in self.model:
+            architectures = ["Qwen2EmbeddingModel"]
+
         if not all(arch in _PP_SUPPORTED_MODELS
                    for arch in architectures) and pipeline_parallel_size > 1:
             raise NotImplementedError(
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index f7e0f56c1a46..587cb883ea68 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -27,6 +27,9 @@ def get_model_architecture(
             and model_config.quantization != "fp8"
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
+    # FIXME: Special handling for gte-Qwen2
+    if "gte-Qwen2" in model_config.model:
+        architectures = ["Qwen2EmbeddingModel"]
 
     for arch in architectures:
         model_cls = ModelRegistry.load_model_cls(arch)
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 94c3cea98be7..4fbe83853cfc 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -81,6 +81,7 @@
 
 _EMBEDDING_MODELS = {
     "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
+    "Qwen2EmbeddingModel": ("qwen2_embedding", "Qwen2EmbeddingModel"),
 }
 
 _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
diff --git a/vllm/model_executor/models/qwen2_embedding.py b/vllm/model_executor/models/qwen2_embedding.py
new file mode 100644
index 000000000000..e1de1547a49a
--- /dev/null
+++ b/vllm/model_executor/models/qwen2_embedding.py
@@ -0,0 +1,58 @@
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+
+class Qwen2EmbeddingModel(nn.Module):
+    """A model that uses Qwen2 with additional embedding functionalities.
+    This class encapsulates the Qwen2ForCausalLM and provides an interface for
+    embedding operations and customized pooling functions.
+    Attributes:
+        model: An instance of Qwen2ForCausalLM used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.model = Qwen2ForCausalLM(config, cache_config, quant_config,
+                                      lora_config)
+
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(weights)
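
Usage sketch (not part of the diff): assuming the registration above is picked up by vLLM's existing embedding entry point (LLM.encode, the same path the e5-mistral test already exercises), one of the newly listed models could be driven roughly as follows. The enforce_eager flag and the printed field are illustrative assumptions, not taken from this change.

from vllm import LLM

# One of the checkpoints added to MODELS above; it should be routed to
# Qwen2EmbeddingModel through the registry and gte-Qwen2 special-casing.
llm = LLM(model="ssmits/Qwen2-7B-Instruct-embed-base", enforce_eager=True)

prompts = ["Hello, my name is", "The capital of France is"]
outputs = llm.encode(prompts)

for output in outputs:
    # Each result should carry the pooled, normalized last-token embedding
    # produced by Qwen2EmbeddingModel's Pooler (PoolingType.LAST, normalize=True).
    print(len(output.outputs.embedding))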