4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -385,8 +385,8 @@ def generate_greedy_logprobs_limit(
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]

-    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
-        return self.model.encode(prompts)
+    def encode(self, prompts: List[str], **kwargs) -> List[List[torch.Tensor]]:
+        return self.model.encode(prompts, **kwargs)

def __enter__(self):
return self
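The **kwargs forwarding above matters for the embedding tests: when is_embedding_model=True the HF-side runner presumably wraps a sentence-transformers model, so extra keyword arguments such as normalize_embeddings go straight through to SentenceTransformer.encode. A minimal standalone sketch of that call (illustrative only, using one of the models under test):

from sentence_transformers import SentenceTransformer

# Illustrative only: load an embedding model and request L2-normalized vectors,
# which is what normalize_embeddings=True does in sentence-transformers.
model = SentenceTransformer("intfloat/e5-mistral-7b-instruct")
embeddings = model.encode(["What is the capital of France?"],
                          normalize_embeddings=True)
print(embeddings.shape)  # (num_prompts, hidden_size)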
8 changes: 5 additions & 3 deletions tests/models/test_embedding.py
@@ -1,13 +1,15 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
"""Compare the outputs of HF and vLLM for embedding models.

Run `pytest tests/models/test_llama_embedding.py`.
Run `pytest tests/models/test_embedding.py`.
"""
import pytest
import torch
import torch.nn.functional as F

MODELS = [
"intfloat/e5-mistral-7b-instruct",
"ssmits/Qwen2-7B-Instruct-embed-base",
"Alibaba-NLP/gte-Qwen2-7B-instruct",
]


@@ -29,7 +31,7 @@ def test_models(
dtype: str,
) -> None:
with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
-        hf_outputs = hf_model.encode(example_prompts)
+        hf_outputs = hf_model.encode(example_prompts, normalize_embeddings=True)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
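Since both backends now return L2-normalized vectors, the comparison boils down to a per-prompt cosine similarity that should be very close to 1. A minimal sketch of such a check (not necessarily the test's exact helper):

import torch
import torch.nn.functional as F

def check_embeddings_close(hf_outputs, vllm_outputs, tol=1e-2):
    # Cosine similarity per prompt; values near 1.0 mean the HF and vLLM
    # embeddings point in (almost) the same direction.
    for hf_emb, vllm_emb in zip(hf_outputs, vllm_outputs):
        sim = F.cosine_similarity(
            torch.as_tensor(hf_emb, dtype=torch.float32),
            torch.as_tensor(vllm_emb, dtype=torch.float32),
            dim=0)
        assert sim.item() > 1.0 - tol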
8 changes: 8 additions & 0 deletions vllm/config.py
@@ -195,6 +195,10 @@ def _verify_tokenizer_mode(self) -> None:

def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
# FIXME: Special handling for gte-Qwen2
if "gte-Qwen2" in self.model:
architectures = ["Qwen2EmbeddingModel"]
Comment on lines +198 to +200
Member: This hardcoded special case based on the model ID/path is not acceptable. For instance, it wouldn't work when a user has downloaded the model locally and passed in a path like --model ~/my-model/.

Author (@0xWelt, Jul 11, 2024): The gte-Qwen2 embedding model's architecture is "Qwen2ForCausalLM", the same as the Qwen2 LLMs. Is there a better way to resolve this ambiguity?

Perhaps we could add an argparse option that explicitly marks a model as an embedding model, rather than inferring this from the model architecture.
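A hypothetical sketch of what such an opt-in flag could look like (illustrative only; this flag is not part of this PR or of vLLM's existing CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument(
    "--embedding-model",
    action="store_true",
    help="Treat the model as an embedding model instead of inferring this "
    "from the architectures listed in its config.",
)
args = parser.parse_args()

# The flag would then force embedding_mode=True regardless of what the
# checkpoint's "architectures" field says.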

Contributor: How about working with upstream to change the "architectures" list, or add an extra "Qwen2EmbeddingModel" entry to it?
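For reference, the ambiguity comes from the checkpoint's config; a quick way to see what vLLM currently has to work with (the printed value reflects what the thread above states):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Alibaba-NLP/gte-Qwen2-7B-instruct")
print(cfg.architectures)
# ['Qwen2ForCausalLM'] today; the upstream suggestion is to advertise
# something like 'Qwen2EmbeddingModel' here instead.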

Member: #9424 should be able to solve this.


self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)

@@ -277,6 +281,10 @@ def verify_with_parallel_config(

pipeline_parallel_size = parallel_config.pipeline_parallel_size
architectures = getattr(self.hf_config, "architectures", [])
# FIXME: Special handling for gte-Qwen series
if "gte-Qwen2" in self.model:
architectures = ["Qwen2EmbeddingModel"]

if not all(arch in _PP_SUPPORTED_MODELS
for arch in architectures) and pipeline_parallel_size > 1:
raise NotImplementedError(
3 changes: 3 additions & 0 deletions vllm/model_executor/model_loader/utils.py
@@ -27,6 +27,9 @@ def get_model_architecture(
and model_config.quantization != "fp8"
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
# FIXME: Special handling for gte-Qwen2
if "gte-Qwen2" in model_config.model:
architectures = ["Qwen2EmbeddingModel"]

for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -81,6 +81,7 @@

_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
"Qwen2EmbeddingModel": ("qwen2_embedding", "Qwen2EmbeddingModel"),
}

_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
58 changes: 58 additions & 0 deletions vllm/model_executor/models/qwen2_embedding.py
@@ -0,0 +1,58 @@
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import Qwen2Config

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class Qwen2EmbeddingModel(nn.Module):
    """A model that uses Qwen2 with additional embedding functionalities.

    This class encapsulates the Qwen2ForCausalLM and provides an interface
    for embedding operations and customized pooling functions.

    Attributes:
        model: An instance of Qwen2ForCausalLM used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.model = Qwen2ForCausalLM(config, cache_config, quant_config,
                                      lora_config)

        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        self.model.load_weights(weights)
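For context, a short usage sketch of the new model class once it is picked up through the registry (the model ID and prompt are illustrative; assumes vLLM's LLM.encode API for embedding models):

from vllm import LLM

# With embedding_mode enabled, vLLM routes requests through the pooler above
# and returns one embedding per prompt via LLM.encode().
llm = LLM(model="Alibaba-NLP/gte-Qwen2-7B-instruct")
outputs = llm.encode(["What is the capital of France?"])
for output in outputs:
    print(len(output.outputs.embedding))  # embedding dimensionality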