From 2719fac48cc83bc3fcf000a2be0077e3fdc6dde7 Mon Sep 17 00:00:00 2001 From: sixsixcoder Date: Thu, 10 Oct 2024 09:58:00 +0000 Subject: [PATCH 01/10] Add GLM-4v support and meet vllm==0.6.2 --- examples/offline_inference_vision_language.py | 15 + .../decoder_only/vision_language/test_glm4.py | 107 ++++++ vllm/model_executor/models/chatglm.py | 349 +++++++++++++++--- .../models/glm4_vision_encoder.py | 298 +++++++++++++++ vllm/model_executor/models/registry.py | 1 + 5 files changed, 718 insertions(+), 52 deletions(-) create mode 100644 tests/models/decoder_only/vision_language/test_glm4.py create mode 100644 vllm/model_executor/models/glm4_vision_encoder.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 5dd539c3d5ee..7361deac2ea2 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -300,6 +300,20 @@ def run_mllama(question: str, modality: str): return llm, prompt, stop_token_ids +# GLM-4v +def run_glm4v(question): + model_name = "THUDM/glm-4v-9b" + + llm = LLM(model=model_name, + tensor_parallel_size=1, + max_model_len=8192, + trust_remote_code=True, + enforce_eager=True) + prompt = question + stop_token_ids = [151329, 151336, 151338] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -316,6 +330,7 @@ def run_mllama(question: str, modality: str): "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, "mllama": run_mllama, + "glm4v": run_glm4v, } diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py new file mode 100644 index 000000000000..196b5647f6c0 --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -0,0 +1,107 @@ +# tests/models/decoder_only/vision_language/test_glm4v.py +import pytest +from typing import List, Optional, Tuple, Type +from vllm.multimodal.utils import rescale_image_size +from ....conftest import (IMAGE_ASSETS, HfRunner, + PromptImageInput, VllmRunner) +from ...utils import check_logprobs_close + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "What's the content of the image?", + "cherry_blossom": + "What is the season?", +}) + +models = ["THUDM/glm-4v-9b"] +target_dtype = "bfloat16" + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + mm_limit: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + # max_model_len should be greater than image_feature_size + with vllm_runner( + model, + max_model_len=4096, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + stop_token_ids = [151329, 151336, 151338] + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + stop_token_ids=stop_token_ids) + for prompts, images in inputs + ] + with hf_runner(model, dtype=dtype) as hf_model: + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.transformer.output_layer + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + ) + for prompts, images in inputs + ] + + for 
hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype: str, max_tokens: int, num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + run_test( + hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 879795c0d595..d9f86e689e77 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -1,42 +1,228 @@ # coding=utf-8 # Adapted from -# https://github.com/THUDM/ChatGLM2-6B +# https://github.com/THUDM/GLM-4 """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from argparse import Namespace +from array import array +from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict import torch +from PIL import Image from torch import nn from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.base import MultiModalData +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) from vllm.transformers_utils.configs import ChatGLMConfig 
-from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) +from .interfaces import SupportsLoRA, SupportsMultiModal + +logger = init_logger(__name__) + + +def calculate_image_placeholder(vision_config): + return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2 + + +def mm_input_mapper_for_glmv( + ctx: InputContext, + data: MultiModalData[object], +) -> Dict: + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + if tokenizer is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + try: + raw_batch_data = tokenizer.apply_chat_template( + conversation=[{ + "role": "user", + "image": data + }], + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True).data + except Exception: + logger.error("Failed to process image (%s)", data) + raise + pixel_values = raw_batch_data['images'] + + return {'pixel_values': pixel_values} + + +def merge_glm_vision_embeddings( + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + vision_embeddings: torch.Tensor, + boi_token_id: int, + eoi_token_id: int, +) -> torch.Tensor: + + boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0] + eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0] + + mask = torch.zeros_like(input_ids, dtype=torch.bool) + + for boi_pos, eoi_pos in zip(boi_positions, eoi_positions): + assert boi_pos < eoi_pos + mask[boi_pos:eoi_pos + 1] = True + inputs_embeds[mask] = vision_embeddings.view(-1, + vision_embeddings.shape[-1]) + return inputs_embeds + + +class GLMImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + +def get_max_glmv_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config(ChatGLMConfig) + + vision_config = getattr(hf_config, 'vision_config', None) + if vision_config is None: + return 1 + elif isinstance(vision_config, dict): + return calculate_image_placeholder(vision_config) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def dummy_data_for_glmv( + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] +) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: + hf_config = ctx.get_hf_config(ChatGLMConfig) + vision_config = getattr(hf_config, 'vision_config', None) + + if vision_config is None: + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) + seq_data = SequenceData(token_ids) + return seq_data, None + elif isinstance(vision_config, dict): + image_size = vision_config["image_size"] + image_placeholder_length = calculate_image_placeholder(vision_config) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] + + [0] * image_placeholder_length + + [hf_config.eoi_token_id]) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0] * (seq_len - image_placeholder_length - 2)) + seq_data = SequenceData(token_ids) + + mm_data = { + "image": Image.new("RGB", (image_size, image_size), color=0) + } + + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def find_all_positions(input_ids: List[int], target: int) -> List[int]: + return [index for index, value in enumerate(input_ids) if value == target] + + +def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): + hf_config = ctx.get_hf_config(ChatGLMConfig) 
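+    # Expand each <boi> ... <eoi> image span in the prompt so that the number
+    # of placeholder tokens matches the vision feature length, i.e.
+    # (image_size // patch_size // 2) ** 2 tokens per image.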
+ vision_config = getattr(hf_config, 'vision_config', None) + + if vision_config is None: + return llm_inputs + elif isinstance(vision_config, dict): + image_placeholder_length = calculate_image_placeholder(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + input_ids = llm_inputs.get("prompt_token_ids") + position_ids = llm_inputs.get("position_ids") + tokenizer = cached_get_tokenizer( + ctx.model_config.model, + trust_remote_code=ctx.model_config.trust_remote_code) + + try: + raw_batch_data = tokenizer.apply_chat_template( + conversation=[{ + "role": "user", + "image": llm_inputs['multi_modal_data']["image"], + "content": llm_inputs['prompt'] + }], + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True).data + except Exception: + logger.error("Failed to process content (%s)", llm_inputs['prompt']) + raise + input_ids = raw_batch_data['input_ids'][0].tolist() + + if position_ids is None: + position_ids = list(range(len(input_ids))) + boi_token_id = hf_config.boi_token_id + eoi_token_id = hf_config.eoi_token_id + boi_positions = find_all_positions(input_ids, boi_token_id) + eoi_positions = find_all_positions(input_ids, eoi_token_id) + + assert len(boi_positions) == len(eoi_positions) + + new_input_ids = [] + new_position_ids = [] + final_processed_position = 0 + final_processed_position = 0 + + for boi_position, eoi_position in zip(boi_positions, eoi_positions): + assert boi_position < eoi_position + new_input_ids.extend(input_ids[final_processed_position:boi_position + + 1]) + new_position_ids.extend( + list(range(final_processed_position, boi_position + 1))) + new_input_ids.extend([input_ids[boi_position + 1]] * + image_placeholder_length) + new_position_ids.extend([boi_position + 1] * image_placeholder_length) + final_processed_position = eoi_position + + new_input_ids.extend(input_ids[final_processed_position:]) + new_position_ids.extend( + list(range(final_processed_position, len(input_ids)))) + + assert len(new_input_ids) == len(new_position_ids) + + llm_inputs["prompt_token_ids"] = new_input_ids + llm_inputs["position_ids"] = new_position_ids + return llm_inputs class GLMAttention(nn.Module): def __init__( self, - config: ChatGLMConfig, + config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -127,7 +313,7 @@ class GLMMLP(nn.Module): def __init__( self, - config: ChatGLMConfig, + config, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -170,7 +356,7 @@ class GLMBlock(nn.Module): def __init__( self, - config: ChatGLMConfig, + config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -241,10 +427,9 @@ class GLMTransformer(nn.Module): def __init__( self, - config: ChatGLMConfig, + config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -253,11 +438,10 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. 
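+        # NOTE: the layers are instantiated directly as an nn.ModuleList here;
+        # this path does not use make_layers / pipeline-parallel sharding.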
- self.start_layer, self.end_layer, self.layers = make_layers( - self.num_layers, - lambda prefix: GLMBlock(config, cache_config, quant_config), - prefix=f"{prefix}.layers", - ) + self.layers = nn.ModuleList([ + GLMBlock(config, cache_config, quant_config) + for i in range(self.num_layers) + ]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -272,16 +456,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - for i in range(self.start_layer, self.end_layer): + for i in range(self.num_layers): layer = self.layers[i] hidden_states = layer( hidden_states=hidden_states, position_ids=position_ids, - kv_cache=kv_caches[i - self.start_layer], + kv_cache=kv_caches[i], attn_metadata=attn_metadata, ) # Final layer norm. - if get_pp_group().is_last_rank and self.post_layer_norm: + if self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) return hidden_states @@ -291,14 +475,17 @@ class ChatGLMModel(nn.Module): def __init__( self, - config: ChatGLMConfig, + config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() + self.config = config + self.embedding = VocabParallelEmbedding(config.padded_vocab_size, - config.hidden_size) + config.hidden_size, + quant_config=quant_config) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num @@ -308,37 +495,73 @@ def __init__( self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size, quant_config=quant_config) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + + vision_config_flag = getattr(config, 'vision_config', None) + if vision_config_flag is not None: + self.vision_config = Namespace(**config.vision_config) + self.vision = EVA2CLIPModel(self.config, quant_config) + else: + self.vision = None + + def _parse_and_validate_image_input( + self, **kwargs: object) -> GLMImagePixelInputs: + + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is not None and self.vision is not None: + if isinstance(pixel_values, torch.Tensor): + if pixel_values.ndim > 2: + pixel_values = torch.concat(list(pixel_values)) + elif isinstance(pixel_values, list): + return torch.concat(pixel_values) + else: + raise TypeError("""pixel_values must be a torch.Tensor + or a list of torch.Tensor + """) + return GLMImagePixelInputs(pixel_values=pixel_values) def forward( self, input_ids: torch.Tensor, - position_ids: torch.Tensor, + positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - inputs_embeds = self.embedding(input_ids) - else: - inputs_embeds = intermediate_tensors["hidden_states"] + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> torch.Tensor: + + inputs_embeds = self.embedding(input_ids) + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input["pixel_values"] is not None: + pixel_values = image_input["pixel_values"].to( + dtype=inputs_embeds.dtype) + image_embeds = self.vision(pixel_values) + + boi_token_id = self.config.boi_token_id + eoi_token_id = self.config.eoi_token_id + + inputs_embeds = merge_glm_vision_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + vision_embeddings=image_embeds, + 
boi_token_id=boi_token_id, + eoi_token_id=eoi_token_id) # Run encoder. hidden_states = self.encoder( hidden_states=inputs_embeds, - position_ids=position_ids, + position_ids=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, ) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) return hidden_states -class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) +class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] @@ -356,6 +579,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, config: ChatGLMConfig, + multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -364,6 +588,7 @@ def __init__( self.config = config self.lora_config = lora_config + self.multimodal_config = multimodal_config self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", @@ -375,19 +600,16 @@ def __init__( self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() - self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, **kwargs) return hidden_states def compute_logits( @@ -408,8 +630,24 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # Merge two ColumnParallelLinear into one MergedColumnParallelLinear + merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = { + "transformer.vision.linear_proj.merged_proj.weight": { + "transformer.vision.linear_proj.gate_proj.weight": None, + "transformer.vision.linear_proj.dense_h_to_4h.weight": None, + } + } + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: + is_weight_to_be_merge = False + for _, merged_weight_dict in merged_weights_dict.items(): + if name in merged_weight_dict: + assert merged_weight_dict[name] is None + merged_weight_dict[name] = loaded_weight + is_weight_to_be_merge = True + if is_weight_to_be_merge: + continue if "rotary_pos_emb.inv_freq" in name: continue if "word_embeddings" in name: @@ -417,9 +655,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue - if is_pp_missing_parameter(name, self): - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + for combined_name, merged_weight_dict in merged_weights_dict.items(): + if combined_name in params_dict: + param = params_dict[combined_name] + combined_weight = torch.cat(list(merged_weight_dict.values()), + dim=0) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, combined_weight) diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py new file mode 100644 index 000000000000..3213a8b29a10 --- /dev/null +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/GLM-4 +"""Inference-only GLM-4v model visual encoder compatible with THUDM weights.""" +from argparse import Namespace +from typing import Optional + +import torch +from torch import nn +from torch.nn import LayerNorm + +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class PatchEmbedding(nn.Module): + + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + images = images.to(self.proj.weight.device) + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class Attention(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = config.num_heads // self.tp_size + self.head_dim = config.hidden_size // config.num_heads + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_dim, + config.num_heads, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + quant_config=quant_config, + ) + + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, L, _ = x.shape + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D + q, k, v = qkv.chunk(3, dim=-1) + q = q.reshape(B, L, self.num_heads_per_rank, + self.head_dim).permute(0, 2, 1, 3) # B, H, L, D + k = k.reshape(B, L, self.num_heads_per_rank, + self.head_dim).permute(0, 2, 1, 3) # B, H, L, D + v = v.reshape(B, L, self.num_heads_per_rank, + self.head_dim).permute(0, 2, 1, 3) # B, H, L, D + 
+ out = torch.nn.functional.scaled_dot_product_attention(q, + k, + v, + attn_mask=None, + dropout_p=0., + is_causal=False) + + output, _ = self.dense(out.transpose(1, 2).view(B, L, -1)) + output = self.output_dropout(output) + return output + + +class MLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.activation_fn(x) + x, _ = self.fc2(x) + return x + + +class TransformerLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = Attention(config, quant_config=quant_config) + self.mlp = MLP(config, quant_config=quant_config) + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class Transformer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.layers = nn.ModuleList([ + TransformerLayer(config, quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + + def __init__( + self, + config, + in_features, + quant_config: Optional[QuantizationConfig] = None, + ): + """ + The original implementation is the same as: + ```python + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + ``` + ``` + gate_proj_output, _ = self.gate_proj(x) + dense_h_to_4h_output, _ = self.dense_h_to_4h(x) + x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1) + ``` + + We merge two ColumnParallelLinear into one MergedColumnParallelLinear: + ``` + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config + ) + ``` + ``` + x, _ = self.merged_proj(x) + ``` + """ + super().__init__() + self.linear_proj = ReplicatedLinear(in_features, + config.hidden_size, + bias=False, + quant_config=quant_config) + self.norm1 = nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = SiluAndMul() + + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config) + + self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config) + + def forward(self, x): + x, _ = self.linear_proj(x) 
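+        # LayerNorm + GELU, then the merged gate/up projection whose output is
+        # combined by SiluAndMul (SwiGLU), and finally the 4h -> h projection.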
+ x = self.act1(self.norm1(x)) + x, _ = self.merged_proj(x) + x = self.act2(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = PatchEmbedding(vision_config) + self.transformer = Transformer(vision_config, + quant_config=quant_config) + self.linear_proj = GLU(config, + in_features=config.hidden_size, + quant_config=quant_config) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + x = self.patch_embedding(images) + x = self.transformer(x) + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f1d484521acb..9def82949452 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -91,6 +91,7 @@ # [Decoder-only] "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), From bf61abefb90c5276fbb936827691ebc857a2e7e2 Mon Sep 17 00:00:00 2001 From: sixsixcoder Date: Fri, 11 Oct 2024 03:17:31 +0000 Subject: [PATCH 02/10] add GLM-4v to supported_models.rst --- docs/source/models/supported_models.rst | 6 ++++++ examples/offline_inference_vision_language.py | 3 ++- vllm/model_executor/models/chatglm.py | 5 +++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 084607c155cb..4a4e88a9f162 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -346,6 +346,12 @@ Text Generation - :code:`adept/fuyu-8b` etc. - - ✅︎ + * - :code:`ChatGLMModel` + - GLM-4v + - Image + - :code:`THUDM/glm-4v-9b` etc. 
+ - + - ✅︎ * - :code:`InternVLChatModel` - InternVL2 - Image\ :sup:`E+` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 7361deac2ea2..76dc2284c4d4 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -301,7 +301,8 @@ def run_mllama(question: str, modality: str): # GLM-4v -def run_glm4v(question): +def run_glm4v(question: str, modality: str): + assert modality == "image" model_name = "THUDM/glm-4v-9b" llm = LLM(model=model_name, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index d9f86e689e77..f26c9f950dd3 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -31,7 +31,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalInputs) from vllm.multimodal.base import MultiModalData from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -72,7 +73,7 @@ def mm_input_mapper_for_glmv( raise pixel_values = raw_batch_data['images'] - return {'pixel_values': pixel_values} + return MultiModalInputs({'pixel_values': pixel_values}) def merge_glm_vision_embeddings( From 8b4701d3339e534bc6dc9759a9d492e0a5323204 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 11 Oct 2024 12:53:58 +0800 Subject: [PATCH 03/10] Use proper capitalization --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 44503c40355e..3090e649976c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -347,7 +347,7 @@ Text Generation - - ✅︎ * - :code:`ChatGLMModel` - - GLM-4v + - GLM-4V - Image - :code:`THUDM/glm-4v-9b` etc. 
- From 38f64bde8d213d4d28537c4538164ea6fe6a9503 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 11 Oct 2024 04:58:10 +0000 Subject: [PATCH 04/10] format --- .../decoder_only/vision_language/test_glm4.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py index 196b5647f6c0..09eb4de67a1e 100644 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -1,9 +1,10 @@ -# tests/models/decoder_only/vision_language/test_glm4v.py -import pytest from typing import List, Optional, Tuple, Type + +import pytest + from vllm.multimodal.utils import rescale_image_size -from ....conftest import (IMAGE_ASSETS, HfRunner, - PromptImageInput, VllmRunner) + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -16,6 +17,7 @@ models = ["THUDM/glm-4v-9b"] target_dtype = "bfloat16" + def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], @@ -30,15 +32,14 @@ def run_test( distributed_executor_backend: Optional[str] = None, ): # max_model_len should be greater than image_feature_size - with vllm_runner( - model, - max_model_len=4096, - max_num_seqs=1, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner(model, + max_model_len=4096, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: stop_token_ids = [151329, 151336, 151338] vllm_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, @@ -52,12 +53,12 @@ def run_test( hf_model.model.get_output_embeddings = lambda: \ hf_model.model.transformer.output_layer hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - ) - for prompts, images in inputs + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + ) for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -69,6 +70,7 @@ def run_test( name_1="vllm", ) + @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", From 964a95555eddcccc1b48b3f05062893a9c57d92f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 11 Oct 2024 14:50:01 +0800 Subject: [PATCH 05/10] Fix registry --- vllm/model_executor/models/registry.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9def82949452..b42a71ac90c5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -26,8 +26,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BloomForCausalLM": ("bloom", "BloomForCausalLM"), - "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), - "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + # ChatGLMModel supports multimodal "CohereForCausalLM": ("commandr", "CohereForCausalLM"), 
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), @@ -68,6 +67,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + # QWenLMHeadModel supports multimodal "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), @@ -92,6 +92,7 @@ "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), From 9325c6f6e70b81aaa6b388360d0cffc4155888b3 Mon Sep 17 00:00:00 2001 From: sixgod Date: Fri, 11 Oct 2024 18:00:26 +0800 Subject: [PATCH 06/10] Update test_glm4.py max_model_len --- tests/models/decoder_only/vision_language/test_glm4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py index 09eb4de67a1e..f43c61ba26e3 100644 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -33,7 +33,7 @@ def run_test( ): # max_model_len should be greater than image_feature_size with vllm_runner(model, - max_model_len=4096, + max_model_len=2048, max_num_seqs=1, dtype=dtype, limit_mm_per_prompt={"image": mm_limit}, From 5f221c43d642b961962e88e79e388bfa22557d18 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 11 Oct 2024 14:03:56 +0000 Subject: [PATCH 07/10] Skip CI for glm4v because of OOM --- examples/offline_inference_vision_language.py | 4 ++-- tests/models/decoder_only/vision_language/test_glm4.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 76dc2284c4d4..8d6818e7dfd3 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -306,8 +306,8 @@ def run_glm4v(question: str, modality: str): model_name = "THUDM/glm-4v-9b" llm = LLM(model=model_name, - tensor_parallel_size=1, - max_model_len=8192, + max_model_len=2048, + max_num_seqs=2, trust_remote_code=True, enforce_eager=True) prompt = question diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py index f43c61ba26e3..4d2b1599e87f 100644 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -5,6 +5,7 @@ from vllm.multimodal.utils import rescale_image_size from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ....utils import large_gpu_test from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -34,7 +35,7 @@ def run_test( # max_model_len should be greater than image_feature_size with vllm_runner(model, max_model_len=2048, - max_num_seqs=1, + max_num_seqs=2, dtype=dtype, limit_mm_per_prompt={"image": mm_limit}, tensor_parallel_size=tensor_parallel_size, @@ -71,6 +72,7 @@ def run_test( ) 
+@large_gpu_test(min_gb=48) @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", From 80f0a33200b3a46f36284ccf5824529bc05dd5d8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 11 Oct 2024 15:01:49 +0000 Subject: [PATCH 08/10] Fix test not working for newer transformers version --- .../decoder_only/vision_language/test_glm4.py | 12 ++++++ vllm/transformers_utils/tokenizer.py | 39 ++++++++++--------- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py index 4d2b1599e87f..94ca08234e6a 100644 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -3,6 +3,7 @@ import pytest from vllm.multimodal.utils import rescale_image_size +from vllm.transformers_utils.tokenizer import patch_padding_side from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner from ....utils import large_gpu_test @@ -51,6 +52,17 @@ def run_test( for prompts, images in inputs ] with hf_runner(model, dtype=dtype) as hf_model: + hf_processor = hf_model.processor + patch_padding_side(hf_processor) + + def processor(*args, images=None, **kwargs): + processed_inputs = hf_processor(*args, **kwargs) + if images is not None: + processed_inputs["images"] = images + + return processed_inputs + + hf_model.processor = processor hf_model.model.get_output_embeddings = lambda: \ hf_model.model.transformer.output_layer hf_outputs_per_image = [ diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 85c339df4a76..94af2388d79d 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -59,6 +59,26 @@ def __len__(self): return tokenizer +def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: + """Patch _pad method to accept `padding_side` for older tokenizers.""" + orig_pad = tokenizer._pad + + def _pad( + self: PreTrainedTokenizer, + *args, + padding_side: Optional[str] = None, + **kwargs, + ): + if padding_side is not None and padding_side != self.padding_side: + msg = ("`padding_side` argument is not supported by " + f"{type(tokenizer).__name__} and will be ignored.") + warnings.warn(msg, stacklevel=2) + + return orig_pad(*args, **kwargs) + + tokenizer._pad = MethodType(_pad, tokenizer) + + def get_tokenizer( tokenizer_name: Union[str, Path], *args, @@ -143,24 +163,7 @@ def get_tokenizer( if type(tokenizer).__name__ in ("ChatGLMTokenizer", "ChatGLM4Tokenizer"): assert isinstance(tokenizer, PreTrainedTokenizer) - orig_pad = tokenizer._pad - - # Patch _pad method to accept `padding_side` - def _pad( - self: PreTrainedTokenizer, - *args, - padding_side: Optional[str] = None, - **kwargs, - ): - if (padding_side is not None - and padding_side != self.padding_side): - msg = ("`padding_side` argument is not supported by " - "ChatGLMTokenizer and will be ignored.") - warnings.warn(msg, stacklevel=2) - - return orig_pad(*args, **kwargs) - - tokenizer._pad = MethodType(_pad, tokenizer) + patch_padding_side(tokenizer) if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( From 6bb546e9135808893f56de571c34e7e7a9e7674e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 11 Oct 2024 15:25:56 +0000 Subject: [PATCH 09/10] Try fix --- .../decoder_only/vision_language/test_glm4.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git 
a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py index 94ca08234e6a..96362c495686 100644 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -51,16 +51,22 @@ def run_test( stop_token_ids=stop_token_ids) for prompts, images in inputs ] + with hf_runner(model, dtype=dtype) as hf_model: hf_processor = hf_model.processor patch_padding_side(hf_processor) - def processor(*args, images=None, **kwargs): - processed_inputs = hf_processor(*args, **kwargs) - if images is not None: - processed_inputs["images"] = images + def processor(*args, text="", images=None, **kwargs): + if images is None: + return hf_processor(*args, **kwargs) - return processed_inputs + return hf_processor.apply_chat_template( + [{"role": "user", "image": images, "content": text}], + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **kwargs, + ) hf_model.processor = processor hf_model.model.get_output_embeddings = lambda: \ From 0008b8f64a8edb4eae4e80f8449753f8840a3c1c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 11 Oct 2024 15:26:56 +0000 Subject: [PATCH 10/10] format --- tests/models/decoder_only/vision_language/test_glm4.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py index 96362c495686..47922a57f680 100644 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -61,7 +61,11 @@ def processor(*args, text="", images=None, **kwargs): return hf_processor(*args, **kwargs) return hf_processor.apply_chat_template( - [{"role": "user", "image": images, "content": text}], + [{ + "role": "user", + "image": images, + "content": text + }], add_generation_prompt=True, tokenize=True, return_dict=True,