[Bug]: GLM-4.1V lora trained model reports target_module mismatch error #22077

@miridih-jyjang

Description

Your current environment

I fine-tuned this model from GLM-4.1V with LoRA using LLaMA-Factory. The LoRA config is:

model

model_name_or_path: zai-org/GLM-4.1V-9B-Thinking #/data/shared/checkpoints/hugging_face/GLM-4.1V-9B-Thinking
#model_name_or_path: saves/glm4_1_v_9b/lora/version020/checkpoint-2000
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true

method

stage: sft
do_train: true
finetuning_type: lora
lora_rank: 32
lora_target: all
deepspeed: examples/deepspeed/ds_z3_offload_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]

dataset

dataset: mllm_dataset #,identity,alpaca_en_demo # video: mllm_video_demo
template: glm4v
cutoff_len: 15000
max_samples: 1000000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

output

output_dir: saves/glm4_1_v_9b/lora/version020
logging_steps: 10
save_steps: 2000
plot_loss: true
overwrite_output_dir: true
save_only_model: false
save_strategy: steps #
metric_for_best_model: loss #
#save_total_limit: 4
report_to: wandb
run_name: glm4_1_v_9b_lora_sft_v020
load_best_model_at_end: true #

train

per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 4.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

eval

eval_dataset: mllm_eval_dataset
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 100
do_predict: false
predict_with_generate: false
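For reference, the modules the adapter was actually trained on are recorded in the checkpoint's adapter_config.json. A minimal sketch of reading that key (the sample config content below is illustrative and written to a temp directory, since the real checkpoint-6000 file isn't attached here):

```python
import json
import os
import tempfile

# Illustrative adapter_config.json content -- NOT the actual file from
# checkpoint-6000. A real LLaMA-Factory/PEFT LoRA checkpoint records the
# trained modules under the "target_modules" key in the same way.
sample_cfg = {
    "peft_type": "LORA",
    "r": 32,
    "target_modules": ["q_proj", "k_proj", "v_proj", "gate_up_proj"],
}

ckpt_dir = tempfile.mkdtemp()
with open(os.path.join(ckpt_dir, "adapter_config.json"), "w") as f:
    json.dump(sample_cfg, f)

# In practice, point ckpt_dir at the real checkpoint directory instead.
with open(os.path.join(ckpt_dir, "adapter_config.json")) as f:
    target_modules = json.load(f)["target_modules"]

print(target_modules)
```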

The output of `DISABLE_VERSION_CHECK=1 python scripts/vllm_infer.py --model_name_or_path zai-org/GLM-4.1V-9B-Thinking --template glm4v --dataset mllm_eval_dataset --adapter_name_or_path /data/shared/jyjang/LLaMA-Factory/saves/glm4_1_v_9b/lora/version020/checkpoint-6000/ --vllm_config '{"max_lora_rank": 32}'` is:
root@dbc3e0868142:/data/shared/jyjang/LLaMA-Factory# CUDA_VISIBLE_DEVICES=0,1 DISABLE_VERSION_CHECK=1 python scripts/vllm_infer.py --model_name_or_path zai-org/GLM-4.1V-9B-Thinking --template glm4v --dataset mllm_eval_dataset --adapter_name_or_path /data/shared/jyjang/LLaMA-Factory/saves/glm4_1_v_9b/lora/version020/checkpoint-6000/ --vllm_config '{"max_lora_rank": 32}'
...
- video_processor: Glm4vVideoProcessor {
  "crop_size": null,
  "data_format": "channels_first",
  "default_to_square": true,
  "device": null,
  "do_center_crop": null,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_pad": null,
  "do_rescale": true,
  "do_resize": true,
  "do_sample_frames": true,
  "fps": 2,
  "image_mean": [0.48145466, 0.4578275, 0.40821073],
  "image_std": [0.26862954, 0.26130258, 0.27577711],
  "input_data_format": null,
  "max_image_size": { "longest_edge": 47040000 },
  "merge_size": 2,
  "num_frames": 16,
  "patch_size": 14,
  "processor_class": "Glm4vProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": { "longest_edge": 47040000, "shortest_edge": 12544 },
  "size_divisor": null,
  "temporal_patch_size": 2,
  "video_metadata": null,
  "video_processor_type": "Glm4vVideoProcessor"
}

{
"processor_class": "Glm4vProcessor"
}

(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] WorkerProc hit an exception.
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/v1/worker/gpu_model_runner.py", line 824, in _prepare_inputs
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] self.set_active_loras(self.input_batch, num_scheduled_tokens)
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/v1/worker/lora_model_runner_mixin.py", line 84, in set_active_loras
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/v1/worker/lora_model_runner_mixin.py", line 73, in _set_active_loras
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/lora/worker_manager.py", line 167, in set_active_adapters
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] WorkerProc hit an exception.
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] set_active_adapters_worker(requests, mapping, self._apply_adapters,
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/adapter_commons/utils.py", line 55, in set_active_adapters_worker
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] Traceback (most recent call last):
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] apply_adapters_func(requests)
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/v1/executor/multiproc_executor.py", line 541, in worker_busy_loop
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/lora/worker_manager.py", line 227, in _apply_adapters
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] output = func(*args, **kwargs)
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] self.add_adapter(lora)
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/lora/worker_manager.py", line 240, in add_adapter
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] lora = self._load_adapter(lora_request)
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] return func(*args, **kwargs)
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/lora/worker_manager.py", line 141, in _load_adapter
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/v1/worker/gpu_worker.py", line 337, in execute_model
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] raise e
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] output = self.model_runner.execute_model(scheduler_output,
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/lora/worker_manager.py", line 116, in _load_adapter
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] lora = self._lora_model_cls.from_local_checkpoint(
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] return func(*args, **kwargs)
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/lora/models.py", line 255, in from_local_checkpoint
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] check_unexpected_modules(f)
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1366, in execute_model
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] File "/opt/conda/lib/python3.11/site-packages/vllm/lora/models.py", line 225, in check_unexpected_modules
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] self._prepare_inputs(scheduler_output))
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] raise ValueError(
(VllmWorker rank=1 pid=62315) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=62314) ERROR 08-01 08:39:02 [multiproc_executor.py:546] ValueError: While loading /data/shared/jyjang/LLaMA-Factory/saves/glm4_1_v_9b/lora/version020/checkpoint-6000/, expected target modules in ['v_proj', 'k_proj', 'up_proj', 'qkv', 'q_proj', '
o_proj', 'down_proj', 'gate_proj', 'proj'] but received ['language_model.model.layers.0.mlp.gate_up_proj', 'language_model.model.layers.0.mlp.gate_up_proj', 'language_model.model.layers.1.mlp.gate_up_proj', 'language_model.model.layers.1.mlp.gate_up_proj', 'language_m
odel.model.layers.10.mlp.gate_up_proj', 'language_model.model.layers.10.mlp.gate_up_proj', 'language_model.model.layers.11.mlp.gate_up_proj', 'language_model.model.layers.11.mlp.gate_up_proj', 'language_model.model.layers.12.mlp.gate_up_proj', 'language_model.model.la
yers.12.mlp.gate_up_proj', 'language_model.model.layers.13.mlp.gate_up_proj', 'language_model.model.layers.13.mlp.gate_up_proj', 'language_model.model.layers.14.mlp.gate_up_proj', 'language_model.model.layers.14.mlp.gate_up_proj', 'language_model.model.layers.15.mlp.g
ate_up_proj', 'language_model.model.layers.15.mlp.gate_up_proj', 'language_model.model.layers.16.mlp.gate_up_proj', 'language_model.model.layers.16.mlp.gate_up_proj', 'language_model.model.layers.17.mlp.gate_up_proj', 'language_model.model.layers.17.mlp.gate_up_proj',
'language_model.model.layers.18.mlp.gate_up_proj', 'language_model.model.layers.18.mlp.gate_up_proj', 'language_model.model.layers.19.mlp.gate_up_proj', 'language_model.model.layers.19.mlp.gate_up_proj', 'language_model.model.layers.2.mlp.gate_up_proj', 'language_mod
el.model.layers.2.mlp.gate_up_proj', 'language_model.model.layers.20.mlp.gate_up_proj', 'language_model.model.layers.20.mlp.gate_up_proj', 'language_model.model.layers.21.mlp.gate_up_proj', 'language_model.model.layers.21.mlp.gate_up_proj', 'language_model.model.layer
s.22.mlp.gate_up_proj', 'language_model.model.layers.22.mlp.gate_up_proj', 'language_model.model.layers.23.mlp.gate_up_proj', 'language_model.model.layers.23.mlp.gate_up_proj', 'language_model.model.layers.24.mlp.gate_up_proj', 'language_model.model.layers.24.mlp.gate
_up_proj', 'language_model.model.layers.25.mlp.gate_up_proj', 'language_model.model.layers.25.mlp.gate_up_proj', 'language_model.model.layers.26.mlp.gate_up_proj', 'language_model.model.layers.26.mlp.gate_up_proj', 'language_model.model.layers.27.mlp.gate_up_proj', 'l
anguage_model.model.layers.27.mlp.gate_up_proj', 'language_model.model.layers.28.mlp.gate_up_proj', 'language_model.model.layers.28.mlp.gate_up_proj', 'language_model.model.layers.29.mlp.gate_up_proj', 'language_model.model.layers.29.mlp.gate_up_proj', 'language_model
.model.layers.3.mlp.gate_up_proj', 'language_model.model.layers.3.mlp.gate_up_proj', 'language_model.model.layers.30.mlp.gate_up_proj', 'language_model.model.layers.30.mlp.gate_up_proj', 'language_model.model.layers.31.mlp.gate_up_proj', 'language_model.model.layers.3
1.mlp.gate_up_proj', 'language_model.model.layers.32.mlp.gate_up_proj', 'language_model.model.layers.32.mlp.gate_up_proj', 'language_model.model.layers.33.mlp.gate_up_proj', 'language_model.model.layers.33.mlp.gate_up_proj', 'language_model.model.layers.34.mlp.gate_up
_proj', 'language_model.model.layers.34.mlp.gate_up_proj', 'language_model.model.layers.35.mlp.gate_up_proj', 'language_model.model.layers.35.mlp.gate_up_proj', 'language_model.model.layers.36.mlp.gate_up_proj', 'language_model.model.layers.36.mlp.gate_up_proj', 'lang
uage_model.model.layers.37.mlp.gate_up_proj', 'language_model.model.layers.37.mlp.gate_up_proj', 'language_model.model.layers.38.mlp.gate_up_proj', 'language_model.model.layers.38.mlp.gate_up_proj', 'language_model.model.layers.39.mlp.gate_up_proj', 'language_model.mo
del.layers.39.mlp.gate_up_proj', 'language_model.model.layers.4.mlp.gate_up_proj', 'language_model.model.layers.4.mlp.gate_up_proj', 'language_model.model.layers.5.mlp.gate_up_proj', 'language_model.model.layers.5.mlp.gate_up_proj', 'language_model.model.layers.6.mlp.
gate_up_proj', 'language_model.model.layers.6.mlp.gate_up_proj', 'language_model.model.layers.7.mlp.gate_up_proj', 'language_model.model.layers.7.mlp.gate_up_proj', 'language_model.model.layers.8.mlp.gate_up_proj', 'language_model.model.layers.8.mlp.gate_up_proj', 'la
nguage_model.model.layers.9.mlp.gate_up_proj', 'language_model.model.layers.9.mlp.gate_up_proj']. Please verify that the loaded LoRA module is correct

🐛 Describe the bug

Title:
ValueError during LoRA inference: Unexpected target modules in checkpoint

Issue Description:
I'm encountering a ValueError in vllm_infer.py when attempting to load a LoRA checkpoint of zai-org/GLM-4.1V-9B-Thinking fine-tuned with LLaMA-Factory. The model was trained with the following settings:

lora_target: all
lora_rank: 32
stage: sft
finetuning_type: lora

When launching inference with this command:

CUDA_VISIBLE_DEVICES=0,1 DISABLE_VERSION_CHECK=1 python scripts/vllm_infer.py \
  --model_name_or_path zai-org/GLM-4.1V-9B-Thinking \
  --template glm4v \
  --dataset mllm_eval_dataset \
  --adapter_name_or_path /data/shared/jyjang/LLaMA-Factory/saves/glm4_1_v_9b/lora/version020/checkpoint-6000/ \
  --vllm_config '{"max_lora_rank": 32}'

I receive the following error:

ValueError: While loading /.../checkpoint-6000/, expected target modules in
['v_proj', 'k_proj', 'up_proj', 'qkv', 'q_proj', 'o_proj', 'down_proj', 'gate_proj', 'proj']
but received
['language_model.model.layers.0.mlp.gate_up_proj',
...
'language_model.model.layers.39.mlp.gate_up_proj']

Notes:
• The full list of unexpected modules consists of entries like language_model.model.layers.X.mlp.gate_up_proj, duplicated once per layer.
• I suspect a mismatch between the LoRA target module names the vLLM runtime expects and the names under which LLaMA-Factory saves adapter weights for this model.
• In LLaMA-Factory, lora_target: all was used, and the model architecture likely names its modules differently from vLLM's defaults.
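To illustrate the mismatch, vLLM effectively compares the final component of each saved module path against its supported list. A simplified sketch of that check (the expected list and the two sample paths are copied from the error above; the suffix logic is my simplification, not vLLM's exact code):

```python
# Modules vLLM reported as expected for this model (from the error above).
expected = {"v_proj", "k_proj", "up_proj", "qkv", "q_proj",
            "o_proj", "down_proj", "gate_proj", "proj"}

# Two of the saved module paths from the checkpoint (full list in the log).
saved = [
    "language_model.model.layers.0.mlp.gate_up_proj",
    "language_model.model.layers.39.mlp.gate_up_proj",
]

# Simplified suffix check: keep only the last dotted component and see
# whether it appears in the supported set.
unexpected = {path.rsplit(".", 1)[-1] for path in saved} - expected
print(sorted(unexpected))
```

The fused gate_up_proj suffix has no match in the supported set, which is consistent with the ValueError.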

Environment:
• vLLM version: 0.10.0 / llamafactory: 0.9.4.dev0 / transformers: 4.54.1
• Model: zai-org/GLM-4.1V-9B-Thinking
• Fine-tuning framework: LLaMA-Factory
• Adapter: LoRA checkpoint @ checkpoint-6000
• LoRA Rank: 32
• Launch device: CUDA_VISIBLE_DEVICES=0,1

Request for Help:
• Is there a way to customize or override the expected target module names during LoRA adapter loading in vLLM?
• Alternatively, is there a way to align the naming between LLaMA-Factory and vLLM for LoRA-compatible adapters?
• Would using gate_up_proj as part of lora_target explicitly help here?
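If part of the problem is only the language_model. wrapper prefix on the saved keys (rather than the fused gate_up_proj module itself), one direction might be renaming keys before loading. This is purely a sketch of the string manipulation on a hypothetical key name; it would not help with gate_up_proj, which has no counterpart in vLLM's supported list:

```python
def strip_wrapper_prefix(key: str, prefix: str = "language_model.") -> str:
    # Drop the multimodal wrapper prefix so the key matches the name
    # the text backbone would use on its own; leave other keys untouched.
    return key[len(prefix):] if key.startswith(prefix) else key

# Hypothetical adapter weight key, not taken from the actual checkpoint.
key = "language_model.model.layers.0.self_attn.q_proj.lora_A.weight"
print(strip_wrapper_prefix(key))
# -> model.layers.0.self_attn.q_proj.lora_A.weight
```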

Thanks in advance!
