Skip to content

Commit 51d7c6a

Browse files
mgoin and DarkLight1337 authored
[Model] Support Mistral3 in the HF Transformers format (#15505)
Signed-off-by: mgoin <[email protected]> Signed-off-by: DarkLight1337 <[email protected]> Co-authored-by: DarkLight1337 <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent f3aca1e commit 51d7c6a

File tree

9 files changed

+723
-4
lines changed

9 files changed

+723
-4
lines changed

docs/source/models/supported_models.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,13 @@ See [this page](#generative-models) for more information on how to use generativ
865865
* ✅︎
866866
* ✅︎
867867
* ✅︎
868+
- * `Mistral3ForConditionalGeneration`
869+
* Mistral3
870+
* T + I<sup>+</sup>
871+
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
872+
*
873+
* ✅︎
874+
*
868875
- * `MllamaForConditionalGeneration`
869876
* Llama 3.2
870877
* T + I<sup>+</sup>

examples/offline_inference/vision_language.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,29 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
498498
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")
499499

500500

501+
# Mistral-3 HF-format
502+
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
503+
assert modality == "image"
504+
505+
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
506+
507+
# NOTE: Need L40 (or equivalent) to avoid OOM
508+
engine_args = EngineArgs(
509+
model=model_name,
510+
max_model_len=8192,
511+
max_num_seqs=2,
512+
tensor_parallel_size=2,
513+
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
514+
)
515+
516+
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
517+
518+
return ModelRequestData(
519+
engine_args=engine_args,
520+
prompts=prompts,
521+
)
522+
523+
501524
# LLama 3.2
502525
def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
503526
assert modality == "image"
@@ -859,6 +882,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
859882
"mantis": run_mantis,
860883
"minicpmo": run_minicpmo,
861884
"minicpmv": run_minicpmv,
885+
"mistral3": run_mistral3,
862886
"mllama": run_mllama,
863887
"molmo": run_molmo,
864888
"NVLM_D": run_nvlm_d,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,28 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
218218
)
219219

220220

221+
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
222+
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
223+
224+
# Adjust this as necessary to fit in GPU
225+
engine_args = EngineArgs(
226+
model=model_name,
227+
max_model_len=8192,
228+
max_num_seqs=2,
229+
tensor_parallel_size=2,
230+
limit_mm_per_prompt={"image": len(image_urls)},
231+
)
232+
233+
placeholders = "[IMG]" * len(image_urls)
234+
prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
235+
236+
return ModelRequestData(
237+
engine_args=engine_args,
238+
prompt=prompt,
239+
image_data=[fetch_image(url) for url in image_urls],
240+
)
241+
242+
221243
def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
222244
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
223245

@@ -509,6 +531,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
509531
"h2ovl_chat": load_h2ovl,
510532
"idefics3": load_idefics3,
511533
"internvl_chat": load_internvl,
534+
"mistral3": load_mistral3,
512535
"mllama": load_mllama,
513536
"NVLM_D": load_nvlm_d,
514537
"phi3_v": load_phi3v,

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ def check_available_online(
297297
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
298298
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
299299
trust_remote_code=True),
300+
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
301+
min_transformers_version="4.50", # noqa: E501
302+
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
300303
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
301304
max_transformers_version="4.48",
302305
transformers_version_reason="Use of private method which no longer exists.", # noqa: E501

vllm/entrypoints/chat_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,8 @@ def _placeholder_str(self, modality: ModalityStr,
487487
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
488488
if model_type in ("minicpmo", "minicpmv"):
489489
return "(<image>./</image>)"
490-
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral"):
490+
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral",
491+
"mistral3"):
491492
# These models do not use image tokens in the prompt
492493
return None
493494
if model_type == "qwen":

0 commit comments

Comments (0)