Skip to content

Conversation

@mgoin
Copy link
Member

@mgoin mgoin commented Jul 18, 2024

FIX #6545

Patch was ported from huggingface/transformers#32050

Essentially there was a new head_dim override added to MistralConfig. We will look for that optional argument in the config and default to the previous self.hidden_size // self.total_num_heads behavior.

We have also produced and validated a FP8 quantized checkpoint: https://huggingface.co/neuralmagic/Mistral-Nemo-Instruct-2407-FP8

>>> from vllm import LLM
>>> model = LLM("mistralai/Mistral-Nemo-Instruct-2407", max_model_len=4096)
>>> model.generate("Hello!")
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.79it/s, est. speed input: 19.19 toks/s, output: 76.75 toks/s]
[RequestOutput(request_id=0, prompt='Hello!', prompt_token_ids=[1, 22177, 1033, 2], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='Hello! How can I assist you today? Let me know if you have any', token_ids=(22177, 1033, 3075, 1710, 1362, 10410, 1636, 9406, 1063, 9246, 1639, 2840, 1693, 1636, 1736, 2258), cumulative_logprob=-1.838211446563946, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721321536.0599637, last_token_time=1721321536.0599637, first_scheduled_time=1721321536.0735824, first_token_time=1721321536.111321, time_in_queue=0.013618707656860352, finished_time=1721321536.2816722), lora_request=None)]

Note that by default it will use a very large model length (128k) and may need max_model_len to be specified.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 18, 2024
@w013nad
Copy link

w013nad commented Jul 18, 2024

Tested merging your commits into 0.5.2 and it works fine. Model works up to 100k tokens(max I can fit into my A100 with fp8 weights/fp8 cache.

@mgoin mgoin enabled auto-merge (squash) July 18, 2024 19:35
@mgoin mgoin merged commit 15c6a07 into vllm-project:main Jul 18, 2024
@mgoin mgoin deleted the mistral-nemo-support branch July 18, 2024 20:31
@jasonacox
Copy link
Contributor

Thanks @mgoin !! @simon-mo Are we able to get this in the upcoming release?

@simon-mo
Copy link
Collaborator

Yes!

@maxin9966
Copy link

@mgoin

env:
vllm 0.5.2

VLLM_ATTENTION_BACKEND=XFORMERS CUDA_VISIBLE_DEVICES=0,1 python -m vllm.entrypoints.openai.api_server --model neuralmagic/Mistral-Nemo-Instruct-2407-FP8 --gpu-memory-utilization 0.75 --quantization fp8 --host 0.0.0.0 --port 1237 -tp 2 --max-model-len 17000 --served-model-name gpt --trust-remote-code --enable-prefix-caching

error:
7月 19 16:20:41 ma-MS-TZZ-Z690M bash[12771]: [rank0]: File "/home/ma/miniconda3/envs/myenv/lib/python3.9/site-packages/vllm/attention/layer.py", line 82, in init
7月 19 16:20:41 ma-MS-TZZ-Z690M bash[12771]: [rank0]: self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
7月 19 16:20:41 ma-MS-TZZ-Z690M bash[12771]: [rank0]: File "/home/ma/miniconda3/envs/myenv/lib/python3.9/site-packages/vllm/attention/backends/xformers.py", line 419, in init
7月 19 16:20:41 ma-MS-TZZ-Z690M bash[12771]: [rank0]: raise ValueError(
7月 19 16:20:41 ma-MS-TZZ-Z690M bash[12771]: [rank0]: ValueError: Head size 160 is not supported by PagedAttention. Supported head sizes are: [64, 80, 96, 112, 128, 192, 256].
7月 19 16:20:41 ma-MS-TZZ-Z690M bash[12771]: ERROR 07-19 16:20:41 multiproc_worker_utils.py:120] Worker VllmWorkerProcess pid 12793 died, exit code: -15
7月 19 16:20:41 ma-MS-TZZ-Z690M bash[12771]: INFO 07-19 16:20:41 multiproc_worker_utils.py:123] Killing local vLLM worker processes

@jasonacox
Copy link
Contributor

@maxin9966 You would need to apply the patch manually. It hasn't been released yet. See https://docs.vllm.ai/en/latest/getting_started/installation.html#build-from-source.

I'm running mistralai/Mistral-Nemo-Instruct-2407 on an A100 with 100k and no issues. Built via docker...

DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag nemo-vllm

@maxin9966
Copy link

@jasonacox Alright, thank you very much.

@mgoin Could you please confirm if the latest code supports the mistral-nemo models running in gptq or awq modes? FP8 is a bit too slow.

@mgoin
Copy link
Member Author

mgoin commented Jul 19, 2024

Yes, mistral-nemo should have the same quantization support as mistral.

@vgoklani
Copy link

@w013nad how did you test fp8 with an A100? I thought fp8 was only supported on newer hardware. thanks!

@mgoin
Copy link
Member Author

mgoin commented Jul 19, 2024

@vgoklani We support FP8 weight-only quantization on >=Ampere GPUs through the FP8 Marlin kernel for decoding speedup #5975

@dionren
Copy link

dionren commented Jul 19, 2024

Tested merging your commits into 0.5.2 and it works fine. Model works up to 100k tokens(max I can fit into my A100 with fp8 weights/fp8 cache.

How much GPU memory is needed by fp16 model and 128K tokens?

@jasonacox
Copy link
Contributor

Testing a single A100, 128k max-model-len, dtype=auto, weights take 23GB but full vram running footprint is 57GB. I'm getting average 42 TPS per session with aggregate throughput of 1,422 TPS using 512 concurrent threads (load testing).

Docker:

git clone https://github.com/vllm-project/vllm.git
cd vllm
DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm-nemo
docker run -d --runtime nvidia --gpus '"device=0"' \
    -v ${PWD}/models:/root/.cache/huggingface \
    -p 8000:8000 \
    -e NVIDIA_DISABLE_REQUIRE=true \
    --env "HF_TOKEN=*******" \
    --ipc=host \
    --name vllm \
    --restart unless-stopped \
    vllm-nemo \
    --model mistralai/Mistral-Nemo-Instruct-2407 \
    --max-model-len 128000 \
    --tensor-parallel-size 1

@tensimixt
Copy link

tensimixt commented Jul 19, 2024

Testing a single A100, 128k max-model-len, dtype=auto, weights take 23GB but full vram running footprint is 57GB. I'm getting average 42 TPS per session with aggregate throughput of 1,422 TPS using 512 concurrent threads (load testing).

Docker:

git clone https://github.com/vllm-project/vllm.git
cd vllm
DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm-nemo
docker run -d --runtime nvidia --gpus '"device=0"' \
    -v ${PWD}/models:/root/.cache/huggingface \
    -p 8000:8000 \
    -e NVIDIA_DISABLE_REQUIRE=true \
    --env "HF_TOKEN=*******" \
    --ipc=host \
    --name vllm \
    --restart unless-stopped \
    vllm-nemo \
    --model mistralai/Mistral-Nemo-Instruct-2407 \
    --max-model-len 128000 \
    --tensor-parallel-size 1

This is great! Do you if know if it will work like this with a LoRA Adapter currently?

@tensimixt
Copy link

tensimixt commented Jul 19, 2024

Tested with FP8 on 2A100s getting 86.60 tok/s
https://huggingface.co/FlorianJc/Mistral-Nemo-Instruct-2407-vllm-fp8/tree/main

@jasonacox
Copy link
Contributor

@tensimixt I would love to see what aggregate (concurrent) tok/s you get with that setup. I use this simple load generator: https://github.com/jasonacox/TinyLLM/blob/main/loadtest.py

@RonanKMcGovern
Copy link
Contributor

@simon-mo the latest docker image will include this next week? Thanks

@dionren
Copy link

dionren commented Jul 21, 2024

Need help, why I can't use fp8?

docker run -d --ipc host \
  --gpus '"device=0"' \
  -v /mnt/cpn-nvme/b11d5292-85ab-e9a4-7eca-31614bb76c91:/mnt/cpn-pod \
  -p 8600:8000 \
  192.168.200.5/pod/vllm-nemo:0.0.1 \
  --max-model-len 131072 \
  --gpu-memory-utilization 0.98 \
  --kv-cache-dtype fp8 \
  --quantization fp8 \
  --model /mnt/cpn-pod/models/FlorianJc/Mistral-Nemo-Instruct-2407-vllm-fp8 \
  --served-model-name FlorianJc/Mistral-Nemo-Instruct-2407-vllm-fp8 \
  --tensor-parallel-size 1
INFO 07-21 03:58:49 selector.py:161] Cannot use FlashAttention-2 backend for FP8 KV cache.
INFO 07-21 03:58:49 selector.py:54] Using XFormers backend.
WARNING 07-21 03:58:51 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
INFO 07-21 03:58:51 selector.py:161] Cannot use FlashAttention-2 backend for FP8 KV cache.
INFO 07-21 03:58:51 selector.py:54] Using XFormers backend.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 292, in <module>
[rank0]:     run_server(args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 230, in run_server
[rank0]:     if llm_engine is not None else AsyncLLMEngine.from_engine_args(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 464, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 378, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 546, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 250, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/executor_base.py", line 47, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 36, in _init_executor
[rank0]:     self.driver_worker.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 139, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 553, in load_model
[rank0]:     self.model = get_model(model_config=self.model_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 21, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 289, in load_model
[rank0]:     quant_method.process_weights_after_loading(module)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/fp8.py", line 449, in process_weights_after_loading
[rank0]:     raise ValueError("Only support per-tensor scaling factor "
[rank0]: ValueError: Only support per-tensor scaling factor for fp8 KV cache

@vladfaust
Copy link

I'm getting [rank0]: RuntimeError: start (0) + length (1280) exceeds dimension size (1024).\n on vllm 0.5.5 with https://huggingface.co/Alex01837178373/Vikhr-Nemo-12B-Instruct-R-21-09-24-Q5_K_M-GGUF. Looks like it doesn't work with GGUF quantization?

@jasonacox
Copy link
Contributor

Ensure that your input sequence length doesn’t exceed the model’s maximum limit. Trim or truncate the input to fit within 1024 tokens.

@vladfaust
Copy link

Ensure that your input sequence length doesn’t exceed the model’s maximum limit. Trim or truncate the input to fit within 1024 tokens.

@jasonacox

Initialization config
INFO 11-08 03:53:37 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='/runpod-volume/huggingface-cache/hub/models--Alex01837178373--Vikhr-Nemo-12B-Instruct-R-21-09-24-Q5_K_M-GGUF/snapshots/ae85bc7602673866ba86e05f0b91cecbe52fffd9/vikhr-nemo-12b-instruct-r-21-09-24-q5_k_m.gguf', speculative_config=None, tokenizer='Vikhrmodels/Vikhr-Nemo-12B-Instruct-R-21-09-24', skip_tokenizer_init=False, tokenizer_mode=auto, revision=ae85bc7602673866ba86e05f0b91cecbe52fffd9, rope_scaling=None, rope_theta=None, tokenizer_revision=7499757d9f41a2965b0e9db94c976e0982e292b8, trust_remote_code=False, dtype=torch.float16, max_seq_len=16384, download_dir=None, load_format=LoadFormat.GGUF, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gguf, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/runpod-volume/huggingface-cache/hub/models--Alex01837178373--Vikhr-Nemo-12B-Instruct-R-21-09-24-Q5_K_M-GGUF/snapshots/ae85bc7602673866ba86e05f0b91cecbe52fffd9/vikhr-nemo-12b-instruct-r-21-09-24-q5_k_m.gguf, use_v2_block_manager=False, enable_prefix_caching=True)

max_seq_len=16384

From model's config.json:

"max_position_embeddings": 1024000.

Input length is irrelevant, because it's Error initializing vLLM engine: start (0) + length (1280) exceeds dimension size (1024).

@mgoin
Copy link
Member Author

mgoin commented Nov 8, 2024

@Isotr0py would you have an idea about GGUF issues with this architecture?

@Isotr0py
Copy link
Member

Isotr0py commented Nov 9, 2024

@vladfaust Can you try updating to the latest 0.6.3.post1 vllm? I can load this model with latest vllm.

@vladfaust
Copy link

@Isotr0py yep, it works with the latest vLLM.

LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature]: mistralai/Mistral-Nemo-Instruct-2407 support