
Conversation

alex-jw-brooks
Contributor

@alex-jw-brooks alex-jw-brooks commented Aug 30, 2024

FIX #962
FIX #7017
FIX #7192

Currently, Qwen models in vLLM skip loading the visual transformer weights. This PR adds support for loading the visual weights (if they're present) and adds multimodal support, e.g., for Qwen-VL and Qwen-VL-Chat.

This PR only concerns Qwen-VL (version 1). For Qwen2-VL, please refer to #7905.

Summary:

  • Adds multimodal input mapper/processor for Qwen models
  • Ports the visual encoder from Qwen-VL/Chat
  • Only initializes the visual model and processes multimodal components if the model has a visual config (see the sketch after this list)
  • Enables .chat for Qwen models and adds an example for Qwen-VL to the offline visual language samples
  • Switches the existing Qwen test to Qwen/Qwen-7B-Chat to make sure we can still load non-multimodal Qwen models
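
The conditional initialization mentioned above can be pictured roughly like the sketch below (made-up class names; not the exact code from this PR):

import torch.nn as nn


class QWenVisualStub(nn.Module):
    """Stand-in for the ported visual encoder (illustration only)."""

    def __init__(self, **visual_cfg):
        super().__init__()


class QWenModelSketch(nn.Module):

    def __init__(self, config):
        super().__init__()
        visual_cfg = getattr(config, "visual", None)
        # Text-only Qwen checkpoints have no "visual" section in their config,
        # so the vision tower and the multimodal mapper/processor are skipped;
        # Qwen-VL / Qwen-VL-Chat configs do have one.
        self.visual = (QWenVisualStub(**visual_cfg)
                       if visual_cfg is not None else None)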

Some examples that may be helpful:

i. Running qwen-vl as a model in the offline inference vision language examples:

$ python examples/offline_inference_vision_language.py --model_type qwen_vl

Sample output: The Tokyo Skytree tower is seen through cherry blossoms.
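
For reference, passing a raw image (rather than the precomputed embeddings shown in example iii below) might look roughly like this; the prompt layout and the PIL input via multi_modal_data are assumptions based on the other examples here, not code lifted from the example script:

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True)

# The image itself is passed separately via multi_modal_data, so nothing
# goes between <img> and </img>.
prompt = "Picture 1: <img></img>\nWhat is shown in this picture?\n"
image = Image.open("cherry_blossom.jpg")  # hypothetical local image file

outputs = llm.generate(
    [{"prompt": prompt, "multi_modal_data": {"image": image}}],
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64))
print(outputs[0].outputs[0].text)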

ii. Example of running a text-only model:

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen-7B-Chat", trust_remote_code=True)

prompt = "<|im_start|>user\nWho were the founders of Microsoft?\n<|im_end|>\n<|im_start|>assistant\n"
stop_token_ids = None

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

inputs = [{"prompt": prompt}]
outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Sample output: Microsoft was founded by Bill Gates and Paul Allen in 1975.<|im_end|>

iii. Visual embeddings example
Multiple pictures may be passed as embeddings. In general, these should be of shape (num_images, 256, 4096), since Qwen-VL/Chat encodes each image into a fixed 256-token context. Sample and output below.

from vllm import LLM, SamplingParams
import torch

# Embeddings for 2 images (i.e., shape [2, 256, 4096]).
# One of these images is the vLLM Tokyo Skytree pic; the other is
# the example used in the Qwen model docs of a girl and her dog.
embeds = torch.load(...)

llm = LLM(model="Qwen/Qwen-VL-Chat", trust_remote_code=True)

# NOTE: You don't need to put anything between <img> / </img> since in vLLM,
# the loaded multimodal data is provided separately.
get_img_prompt = lambda img_num: f"Picture {img_num}: <img></img>\n"
prompt = f"<|im_start|>Picture 1: {get_img_prompt(1)} {get_img_prompt(2)} Can you compare these two pictures in english?\n<|im_end|>\n<|im_start|>assistant\n"
stop_token_ids = None

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

inputs = [{"prompt": prompt, "multi_modal_data": {"image": embeds}}]
outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Sample output: Picture 1 is of a woman sitting on the beach with her dog, both of them holding hands and smiling at each other. Picture 2 is of the Tokyo Skytree tower in Japan, surrounded by pink cherry blossom trees.<|im_end|>

iv. Chat example
Here's an example of calling Qwen-VL-Chat with an image through the OpenAI-compatible API server, using the sample ChatML template.

Start the server:

python vllm/entrypoints/openai/api_server.py \
    --device cuda \
    --model Qwen/Qwen-VL-Chat \
    --tokenizer Qwen/Qwen-VL-Chat \
    --trust-remote-code \
    --api-key token-abc123 \
    --chat-template examples/template_chatml.jinja &

Client example:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

completion = client.chat.completions.create(
  model="Qwen/Qwen-VL-Chat",
  messages=[
    {
        "role": "user", "content": [
          {"type": "image_url", "image_url": {"url": "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"}},
          {"type": "text", "text": "Describe this image in English. "},
        ]
    }
  ]
)

print(completion.choices[0].message)

Example Response:

ChatCompletionMessage(content="A radar chart is shown with several axes, including 'VQA2v3', 'GQA', 'LmivaBench', 'SEED-Bench', 'VizWiz', 'SQA-IMG', 'MMBench-CN', 'TextVQA', 'BLIP-2', 'InstructBLIP', 'Qwen-VL-Chat', and 'LLA-VA.1.5'. Each axis has a value associated with it, with 'VQA2v3' being the highest, and 'LmivaBench' being the lowest. Some axes also have negative values.<|im_end|>\n<|im_start|>\n", refusal=None, role='assistant', function_call=None, tool_calls=[])

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain code quality and improves the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented so that future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. This helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and might not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status updates every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@alex-jw-brooks alex-jw-brooks changed the title from "Qwen multimodal" to "[MODEL] Qwen Multimodal Support (Qwen-VL / Qwen-VL-Chat)" on Aug 30, 2024
@alex-jw-brooks alex-jw-brooks marked this pull request as ready for review September 1, 2024 13:37
@alex-jw-brooks
Contributor Author

/ready

@alex-jw-brooks
Contributor Author

Cool, sounds good, thanks @DarkLight1337! 🤞

I saw you had resolved this comment: #8029 (comment) - I added parallel linear layers for the MLP in the visual encoder, but am still trying to rework the VisualAttention to use QKVParallelLinear and memory_efficient_attention_forward from xformers, like most of the other visual encoders currently implemented in vLLM do.

Did you want me to try to get that into this PR once the test is resolved, or would it be better off in a follow-up PR to optimize this model? I think the rest of the changes should be taken care of 🙂

@DarkLight1337
Member

Did you want me to try to get that into this PR once the test is resolved, or would it be better off in a follow-up PR to optimize this model?

I wanted to parallelize the MLP first as it's easier. We can parallelize the attention module in another PR, as it's a bit more complicated.
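
For reference, a minimal sketch of what the tensor-parallel visual-encoder MLP might look like with vLLM's parallel linear layers (the attribute names and activation here are assumptions, not the exact code merged in this PR; vLLM's parallel linear layers return an (output, bias) tuple):

import torch
import torch.nn as nn

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)


class VisualMLPSketch(nn.Module):
    """Tensor-parallel MLP for the visual encoder (illustrative only)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Column-parallel: shards the intermediate dimension across GPUs.
        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size)
        self.act = nn.GELU()
        # Row-parallel: all-reduces the partial results back together.
        self.c_proj = RowParallelLinear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.c_fc(x)
        x = self.act(x)
        x, _ = self.c_proj(x)
        return x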

@alex-jw-brooks
Contributor Author

Cool, that sounds good to me!

@DarkLight1337
Member

DarkLight1337 commented Sep 5, 2024

For the dummy data, you should pad the input with text tokens so that (combined with the image tokens) there are at least a total of seq_len tokens. You can see dummy_seq_data_for_clip for an example.

@alex-jw-brooks
Contributor Author

Nice catch! Pushed the fix to pad it if the image prompt isn't long enough
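
Roughly, the padding idea described above looks like this (a sketch with hypothetical names, not the exact helper added in this PR):

def pad_dummy_prompt_tokens(image_token_ids: list,
                            seq_len: int,
                            pad_token_id: int = 0) -> list:
    """Pad the image placeholder tokens with text tokens up to seq_len.

    The dummy sequence used for memory profiling must contain at least
    seq_len tokens in total, so if the image placeholders alone are shorter
    than seq_len, the remainder is filled with dummy text tokens.
    """
    num_pad = max(0, seq_len - len(image_token_ids))
    return image_token_ids + [pad_token_id] * num_pad


# e.g. a 256-token image context padded out to a 2048-token dummy sequence
dummy = pad_dummy_prompt_tokens(list(range(256)), seq_len=2048)
assert len(dummy) == 2048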

Member

@DarkLight1337 DarkLight1337 left a comment

VLM tests pass now. Thanks again for your effort!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) September 5, 2024 12:37
@DarkLight1337 DarkLight1337 merged commit 9da25a8 into vllm-project:main Sep 5, 2024
49 checks passed
@DarkLight1337
Member

The PR has been merged. Some follow-ups to be done:

  • TP support for vision encoder, particularly the transformer module.
  • Testing multi-image input for Qwen-VL so we can officially support it in the docs (a hypothetical usage sketch follows below).
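
Once multi-image input is verified, usage might look roughly like the sketch below. This is purely hypothetical: list-valued multi_modal_data and limit_mm_per_prompt are assumed to work for Qwen-VL, which is exactly what still needs testing.

from PIL import Image

from vllm import LLM, SamplingParams

# Hypothetical: allow up to 2 images per prompt.
llm = LLM(model="Qwen/Qwen-VL-Chat",
          trust_remote_code=True,
          limit_mm_per_prompt={"image": 2})

prompt = ("<|im_start|>user\n"
          "Picture 1: <img></img>\nPicture 2: <img></img>\n"
          "Can you compare these two pictures?\n"
          "<|im_end|>\n<|im_start|>assistant\n")

# Hypothetical local files standing in for the two images.
images = [Image.open("skytree.jpg"), Image.open("girl_and_dog.jpg")]

outputs = llm.generate(
    [{"prompt": prompt, "multi_modal_data": {"image": images}}],
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64))
print(outputs[0].outputs[0].text)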

@zhangfan-algo

Can we support qwen2-vl-7B?

@DarkLight1337
Member

Can we support qwen2-vl-7B?

It is WIP in #7905

@syngokhan

I want to launch the model as an API server, but I am getting this error. I have downloaded and installed vLLM in its latest form (v0.6.0 and later).

(image) root@gapxivrgpup03:/home/glb90108385# CUDA_VISIBLE_DEVICES=1 python /opt/VLLM_IMAGE/vllm/vllm/entrypoints/openai/api_server.py --model /opt/GPT/MODELS/Qwen2-VL-7B-Instruct/ --host 10.12.112.162 --port 9002 --tensor-parallel-size 1 --trust-remote-code  --max-model-len 32000 --enforce-eager --gpu-memory-utilization 1.0
INFO 09-06 06:36:52 api_server.py:495] vLLM API server version 0.6.0
INFO 09-06 06:36:52 api_server.py:496] args: Namespace(host='10.12.112.162', port=9002, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, model='/opt/GPT/MODELS/Qwen2-VL-7B-Instruct/', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, download_dir=None, load_format='auto', dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=32000, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=1.0, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=True, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, override_neuron_config=None, engine_use_ray=False, disable_log_requests=False, max_log_len=None)
Traceback (most recent call last):
  File "/opt/VLLM_IMAGE/vllm/vllm/entrypoints/openai/api_server.py", line 531, in <module>
    asyncio.run(run_server(args))
  File "/opt/anaconda3/envs/image/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/anaconda3/envs/image/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/opt/VLLM_IMAGE/vllm/vllm/entrypoints/openai/api_server.py", line 498, in run_server
    async with build_async_engine_client(args) as async_engine_client:
  File "/opt/anaconda3/envs/image/lib/python3.10/contextlib.py", line 199, in __aenter__
    return await anext(self.gen)
  File "/opt/VLLM_IMAGE/vllm/vllm/entrypoints/openai/api_server.py", line 110, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
  File "/opt/anaconda3/envs/image/lib/python3.10/contextlib.py", line 199, in __aenter__
    return await anext(self.gen)
  File "/opt/VLLM_IMAGE/vllm/vllm/entrypoints/openai/api_server.py", line 132, in build_async_engine_client_from_engine_args
    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
  File "/opt/VLLM_IMAGE/vllm/vllm/entrypoints/openai/api_server.py", line 73, in model_is_embedding
    return ModelConfig(model=model_name,
  File "/opt/VLLM_IMAGE/vllm/vllm/config.py", line 224, in __init__
    self.max_model_len = _get_and_verify_max_len(
  File "/opt/VLLM_IMAGE/vllm/vllm/config.py", line 1740, in _get_and_verify_max_len
    assert "factor" in rope_scaling
AssertionError

@DarkLight1337
Member

I want to launch the model as an API server, but I am getting this error. [...]

This PR only adds support for Qwen-VL (version 1). For Qwen2-VL, please refer to #7905.

@zhangfan-algo

I want to launch the model as an API server, but I am getting this error. [...]

This PR only adds support for Qwen-VL (version 1). For Qwen2-VL, please refer to #7905.

I still have the same bug after pulling down the latest GitHub code.

@DarkLight1337
Member

It's not a bug. Qwen2-VL hasn't been added to vLLM yet. Please read my above comment.

@zhangfan-algo

When do we expect to support the Qwen2-VL series?

@DarkLight1337
Member

We are waiting for transformers to update so that we can load Qwen2-VL from their config directly.

dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Sep 12, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
@MotorBottle

Currently, Qwen models in vLLM skip loading the visual transformer weights. This PR adds support for loading the visual weights (if they're present) and adds multimodal support, e.g., for Qwen-VL and Qwen-VL-Chat. [...]

Hi, I followed the sample code to deploy Qwen-VL-Chat with the vLLM Docker image. While deployment was successful, I keep getting out-of-vocabulary (OOV) errors no matter how I test my inputs.

How I deployed:

sudo docker run --runtime nvidia --gpus '"device=0,1"' --ipc=host -p 18434:8000   -v hf_cache:/root/.cache/huggingface   -d   -e HF_ENDPOINT=https://hf-mirror.com   -e HF_HUB_ENABLE_HF_TRANSFER=0   --name Qwen-VL-Chat   vllm/vllm-openai:latest   --model Qwen/Qwen-VL-Chat   --tokenizer Qwen/Qwen-VL-Chat   --tensor-parallel-size 2   --trust-remote-code   --chat-template examples/template_chatml.jinja   --dtype='half'

Error msg:

Error in API call: 400 {"object":"error","message":"Token id 151859 is out of vocabulary","type":"BadRequestError","param":null,"code":400}

Test code:

import requests
import base64
import time

# Function to encode the image to base64
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

def main():
    # Path to your image
    image_path = "test2.jpg"
    base64_image = encode_image(image_path)

    # API configuration
    api_base = "http://192.168.50.18:18434/v1/chat/completions"
    model_name = "Qwen/Qwen-VL-Chat"

    # Input prompt
    user_prompt_text = (
        "What's inside the image?"
    )

    # Prepare the payload
    payload_template = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": [
                    # {"type": "image_url", "image_url": {"url": "https://i.imgur.com/T3S0cvu.jpeg"}},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
                    {"type": "text", "text": user_prompt_text}
                ]
            }
        ],
        "max_tokens": 300
    }

    for i in range(1, 2):
        print(f"===== API called {i} times =====")
        startTime = time.time()

        response = requests.post(api_base, json=payload_template)

        if response.status_code != 200:
            print("Error in API call:", response.status_code, response.text)
        else:
            completion = response.json()["choices"][0]["message"]["content"]
            tokens = response.json()["usage"]["prompt_tokens"]
            print("Model Response:", completion)
            print("tokens:", tokens)

        print("time used: {:.2f} 秒".format(time.time() - startTime))
        print()

if __name__ == "__main__":
    main()

I searched the web and could not find any similar case, so I'm replying here in the hope of some help.

Much appreciated!

LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025

Labels

ready (ONLY add when PR is ready to merge/full CI is needed)

Development

Successfully merging this pull request may close these issues.

  • [Model]: Support for Qwen-VL model
  • [Feature]: Not support Qwen-VL-Chat
  • can model Qwen/Qwen-VL-Chat work well?

5 participants