From f801d108c13c3805916dce5fc2e037ba2e9a00be Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Tue, 20 May 2025 11:16:36 +0200
Subject: [PATCH 1/4] Convert `examples` to `ruff format`
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
.pre-commit-config.yaml | 2 +-
examples/pyproject.toml | 54 +++++++++++++++++++++++++++++++++++++++++
pyproject.toml | 2 ++
3 files changed, 57 insertions(+), 1 deletion(-)
create mode 100644 examples/pyproject.toml
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f5c0c368d578..b5d23ced4376 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
- id: ruff
args: [--output-format, github, --fix]
- id: ruff-format
- files: ^(.buildkite|benchmarks)/.*
+ files: ^(.buildkite|benchmarks|examples)/.*
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
diff --git a/examples/pyproject.toml b/examples/pyproject.toml
new file mode 100644
index 000000000000..f825cb203269
--- /dev/null
+++ b/examples/pyproject.toml
@@ -0,0 +1,54 @@
+# This local pyproject file is part of the migration from yapf to ruff format.
+# It uses the same core rules as the main pyproject.toml file, but with the
+# following differences:
+# - ruff line length is overridden to 88
+# - deprecated typing ignores (UP006, UP035) have been removed
+
+[tool.ruff]
+line-length = 88
+exclude = [
+ # External file, leaving license intact
+ "examples/other/fp8/quantizer/quantize.py",
+ "vllm/vllm_flash_attn/flash_attn_interface.pyi"
+]
+
+[tool.ruff.lint.per-file-ignores]
+"vllm/third_party/**" = ["ALL"]
+"vllm/version.py" = ["F401"]
+"vllm/_version.py" = ["ALL"]
+
+[tool.ruff.lint]
+select = [
+ # pycodestyle
+ "E",
+ # Pyflakes
+ "F",
+ # pyupgrade
+ "UP",
+ # flake8-bugbear
+ "B",
+ # flake8-simplify
+ "SIM",
+ # isort
+ "I",
+ # flake8-logging-format
+ "G",
+]
+ignore = [
+ # star imports
+ "F405", "F403",
+ # lambda expression assignment
+ "E731",
+ # Loop control variable not used within loop body
+ "B007",
+ # f-string format
+ "UP032",
+ # Can remove once 3.10+ is the minimum Python version
+ "UP007",
+]
+
+[tool.ruff.lint.isort]
+known-first-party = ["vllm"]
+
+[tool.ruff.format]
+docstring-code-format = true
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 0b803a26b658..7204b945979e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,7 @@ ignore_patterns = [
".buildkite/**",
"benchmarks/**",
"build/**",
+ "examples/**",
]
[tool.ruff]
@@ -148,6 +149,7 @@ skip = "tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora
skip_glob = [
".buildkite/*",
"benchmarks/*",
+ "examples/*",
]
use_parentheses = true
skip_gitignore = true
From 66798a3ea4d48234aafa8930b43a5059b8a72e03 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Tue, 20 May 2025 11:33:09 +0200
Subject: [PATCH 2/4] `pre-commit run -a`
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
examples/lmcache/cpu_offload_lmcache.py | 25 +-
examples/lmcache/disagg_prefill_lmcache_v0.py | 59 +--
.../disagg_proxy_server.py | 100 ++--
.../lmcache/kv_cache_sharing_lmcache_v1.py | 41 +-
examples/offline_inference/audio_language.py | 150 +++---
examples/offline_inference/basic/chat.py | 18 +-
examples/offline_inference/basic/classify.py | 15 +-
examples/offline_inference/basic/embed.py | 14 +-
examples/offline_inference/basic/score.py | 6 +-
.../offline_inference/batch_llm_inference.py | 22 +-
examples/offline_inference/chat_with_tools.py | 110 ++---
examples/offline_inference/data_parallel.py | 111 +++--
.../decode_example.py | 23 +-
.../prefill_example.py | 19 +-
.../disaggregated_prefill.py | 45 +-
examples/offline_inference/eagle.py | 50 +-
.../embed_jina_embeddings_v3.py | 19 +-
.../offline_inference/embed_matryoshka_fy.py | 15 +-
examples/offline_inference/encoder_decoder.py | 40 +-
.../encoder_decoder_multimodal.py | 62 +--
.../offline_inference/llm_engine_example.py | 34 +-
.../offline_inference/load_sharded_state.py | 38 +-
.../lora_with_quantization_inference.py | 116 ++---
examples/offline_inference/mistral-small.py | 64 +--
examples/offline_inference/mlpspeculator.py | 11 +-
.../offline_inference/multilora_inference.py | 86 ++--
examples/offline_inference/neuron.py | 3 +-
examples/offline_inference/neuron_eagle.py | 6 +-
.../neuron_int8_quantization.py | 9 +-
.../offline_inference/neuron_speculation.py | 10 +-
examples/offline_inference/prefix_caching.py | 21 +-
.../prithvi_geospatial_mae.py | 179 ++++----
examples/offline_inference/profiling.py | 225 +++++----
.../profiling_tpu/profiling.py | 69 +--
.../qwen2_5_omni/only_thinker.py | 124 ++---
examples/offline_inference/qwen_1m.py | 30 +-
examples/offline_inference/rlhf.py | 21 +-
examples/offline_inference/rlhf_colocate.py | 23 +-
examples/offline_inference/rlhf_utils.py | 32 +-
.../offline_inference/save_sharded_state.py | 41 +-
.../offline_inference/structured_outputs.py | 43 +-
.../offline_inference/torchrun_example.py | 3 +-
examples/offline_inference/tpu.py | 10 +-
examples/offline_inference/vision_language.py | 427 +++++++++---------
.../vision_language_embedding.py | 71 +--
.../vision_language_multi_image.py | 425 +++++++++--------
examples/online_serving/api_client.py | 22 +-
.../online_serving/cohere_rerank_client.py | 17 +-
.../disagg_proxy_demo.py | 188 ++++----
.../gradio_openai_chatbot_webserver.py | 82 ++--
examples/online_serving/gradio_webserver.py | 32 +-
.../online_serving/jinaai_rerank_client.py | 12 +-
.../online_serving/kv_events_subscriber.py | 16 +-
.../openai_chat_completion_client.py | 23 +-
...i_chat_completion_client_for_multimodal.py | 257 +++++------
...penai_chat_completion_client_with_tools.py | 147 +++---
...t_completion_client_with_tools_required.py | 61 +--
...enai_chat_completion_structured_outputs.py | 80 ++--
...etion_structured_outputs_structural_tag.py | 44 +-
...etion_structured_outputs_with_reasoning.py | 59 ++-
...at_completion_tool_calls_with_reasoning.py | 158 +++----
.../openai_chat_completion_with_reasoning.py | 12 +-
...hat_completion_with_reasoning_streaming.py | 4 +-
...ai_chat_embedding_client_for_multimodal.py | 127 +++---
.../openai_classification_client.py | 4 +-
.../openai_completion_client.py | 3 +-
.../openai_cross_encoder_score.py | 13 +-
.../online_serving/openai_embedding_client.py | 2 +-
.../online_serving/openai_pooling_client.py | 21 +-
.../openai_transcription_client.py | 30 +-
.../opentelemetry/dummy_client.py | 9 +-
examples/online_serving/ray_serve_deepseek.py | 4 +-
...val_augmented_generation_with_langchain.py | 113 ++---
...al_augmented_generation_with_llamaindex.py | 97 ++--
.../streamlit_openai_chatbot_webserver.py | 55 +--
examples/online_serving/utils.py | 6 +-
examples/other/tensorize_vllm_model.py | 8 +-
77 files changed, 2421 insertions(+), 2350 deletions(-)
diff --git a/examples/lmcache/cpu_offload_lmcache.py b/examples/lmcache/cpu_offload_lmcache.py
index eedb47dfc12e..f76fdc987d6b 100644
--- a/examples/lmcache/cpu_offload_lmcache.py
+++ b/examples/lmcache/cpu_offload_lmcache.py
@@ -20,6 +20,7 @@
Learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""
+
import argparse
import contextlib
import os
@@ -28,7 +29,6 @@
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
-
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs
@@ -49,8 +49,7 @@ def setup_environment_variables(vllm_version: str):
@contextlib.contextmanager
-def build_llm_with_lmcache(lmcache_connector: str, model: str,
- vllm_version: str):
+def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
ktc = KVTransferConfig(
kv_connector=lmcache_connector,
kv_role="kv_both",
@@ -97,18 +96,19 @@ def print_output(
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
- print(f"Generation took {time.time() - start:.2f} seconds, "
- f"{req_str} request done.")
+ print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
print("-" * 50)
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("-v",
- "--version",
- choices=["v0", "v1"],
- default="v1",
- help="Specify vLLM version (default: v1)")
+ parser.add_argument(
+ "-v",
+ "--version",
+ choices=["v0", "v1"],
+ default="v1",
+ help="Specify vLLM version (default: v1)",
+ )
return parser.parse_args()
@@ -125,7 +125,6 @@ def main():
setup_environment_variables(args.version)
with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
-
# This example script runs two requests with a shared prefix.
# Define the shared prompt and specific prompts
shared_prompt = "Hello, how are you?" * 1000
@@ -136,9 +135,7 @@ def main():
shared_prompt + "Tell me a very long story",
]
- sampling_params = SamplingParams(temperature=0,
- top_p=0.95,
- max_tokens=10)
+ sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
# Print the first output
print_output(llm, first_prompt, sampling_params, "first")
diff --git a/examples/lmcache/disagg_prefill_lmcache_v0.py b/examples/lmcache/disagg_prefill_lmcache_v0.py
index 66cc94185230..8a1676d53a05 100644
--- a/examples/lmcache/disagg_prefill_lmcache_v0.py
+++ b/examples/lmcache/disagg_prefill_lmcache_v0.py
@@ -4,12 +4,13 @@
with LMCache.
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and launch an additional LMCache server.
-KV cache is transferred in the following manner:
+KV cache is transferred in the following manner:
vLLM prefill node -> LMCache server -> vLLM decode node.
Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
+
import os
import subprocess
import time
@@ -17,7 +18,6 @@
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
-
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
@@ -49,19 +49,23 @@ def run_prefill(prefill_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
- ktc = KVTransferConfig(kv_connector="LMCacheConnector",
- kv_role="kv_producer",
- kv_rank=0,
- kv_parallel_size=2)
+ ktc = KVTransferConfig(
+ kv_connector="LMCacheConnector",
+ kv_role="kv_producer",
+ kv_rank=0,
+ kv_parallel_size=2,
+ )
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
- llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
- kv_transfer_config=ktc,
- max_model_len=8000,
- gpu_memory_utilization=0.8,
- enforce_eager=True)
-
- #llm.generate(prompts, sampling_params)
+ llm = LLM(
+ model="mistralai/Mistral-7B-Instruct-v0.2",
+ kv_transfer_config=ktc,
+ max_model_len=8000,
+ gpu_memory_utilization=0.8,
+ enforce_eager=True,
+ )
+
+ # llm.generate(prompts, sampling_params)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
@@ -79,17 +83,21 @@ def run_decode(prefill_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
- ktc = KVTransferConfig(kv_connector="LMCacheConnector",
- kv_role="kv_consumer",
- kv_rank=1,
- kv_parallel_size=2)
+ ktc = KVTransferConfig(
+ kv_connector="LMCacheConnector",
+ kv_role="kv_consumer",
+ kv_rank=1,
+ kv_parallel_size=2,
+ )
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
- llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
- kv_transfer_config=ktc,
- max_model_len=8000,
- gpu_memory_utilization=0.8,
- enforce_eager=True)
+ llm = LLM(
+ model="mistralai/Mistral-7B-Instruct-v0.2",
+ kv_transfer_config=ktc,
+ max_model_len=8000,
+ gpu_memory_utilization=0.8,
+ enforce_eager=True,
+ )
print("Waiting for prefill node to finish...")
prefill_done.wait()
@@ -105,10 +113,9 @@ def run_decode(prefill_done, prompts, timeout=1):
def run_lmcache_server(port):
- server_proc = subprocess.Popen([
- "python", "-m", "lmcache.experimental.server", "localhost",
- str(port)
- ])
+ server_proc = subprocess.Popen(
+ ["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
+ )
return server_proc
diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py
index 8db93bc8931b..06a323fcb907 100644
--- a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py
+++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py
@@ -17,13 +17,17 @@ async def lifespan(app: FastAPI):
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize clients
- prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
- decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
-
- app.state.prefill_client = httpx.AsyncClient(timeout=None,
- base_url=prefiller_base_url)
- app.state.decode_client = httpx.AsyncClient(timeout=None,
- base_url=decoder_base_url)
+ prefiller_base_url = (
+ f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1"
+ )
+ decoder_base_url = (
+ f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1"
+ )
+
+ app.state.prefill_client = httpx.AsyncClient(
+ timeout=None, base_url=prefiller_base_url
+ )
+ app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url)
yield
@@ -37,7 +41,6 @@ async def lifespan(app: FastAPI):
class StatsCalculator:
-
def __init__(self):
self._stats = []
self._last_log_time = time.time()
@@ -51,13 +54,18 @@ def add(self, value):
def _log_stats(self):
# Print average, median, and 99th percentile
np_arr = np.array(self._stats)
- output_str = f"\nNum requests: {len(self._stats)}" + \
- "\nPrefill node TTFT stats:" + \
- f"\n - Average (ms): {np.mean(np_arr)}" + \
- f"\n - Median (ms): {np.median(np_arr)}" + \
- f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
- print("===============================", output_str,
- "===============================")
+ output_str = (
+ f"\nNum requests: {len(self._stats)}"
+ + "\nPrefill node TTFT stats:"
+ + f"\n - Average (ms): {np.mean(np_arr)}"
+ + f"\n - Median (ms): {np.median(np_arr)}"
+ + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
+ )
+ print(
+ "===============================",
+ output_str,
+ "===============================",
+ )
stats_calculator = StatsCalculator()
@@ -82,15 +90,16 @@ def parse_args():
app.state.decode_client = None
-async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
- req_data: dict):
+async def send_request_to_service(
+ client: httpx.AsyncClient, endpoint: str, req_data: dict
+):
"""
Send a request to a service using a persistent client.
"""
req_data = req_data.copy()
- req_data['max_tokens'] = 1
- if 'max_completion_tokens' in req_data:
- req_data['max_completion_tokens'] = 1
+ req_data["max_tokens"] = 1
+ if "max_completion_tokens" in req_data:
+ req_data["max_completion_tokens"] = 1
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
response = await client.post(endpoint, json=req_data, headers=headers)
@@ -98,14 +107,16 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
return response
-async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
- req_data: dict):
+async def stream_service_response(
+ client: httpx.AsyncClient, endpoint: str, req_data: dict
+):
"""
Asynchronously stream the response from a service using a persistent client.
"""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
- async with client.stream("POST", endpoint, json=req_data,
- headers=headers) as response:
+ async with client.stream(
+ "POST", endpoint, json=req_data, headers=headers
+ ) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
@@ -121,28 +132,28 @@ async def handle_completions(request: Request):
req_data = await request.json()
# Send request to prefill service, ignore the response
- await send_request_to_service(app.state.prefill_client, "/completions",
- req_data)
+ await send_request_to_service(
+ app.state.prefill_client, "/completions", req_data
+ )
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
- async for chunk in stream_service_response(app.state.decode_client,
- "/completions",
- req_data):
+ async for chunk in stream_service_response(
+ app.state.decode_client, "/completions", req_data
+ ):
yield chunk
- return StreamingResponse(generate_stream(),
- media_type="application/json")
+ return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import sys
import traceback
+
exc_info = sys.exc_info()
- print("Error occurred in disagg prefill proxy server"
- " - completions endpoint")
+ print("Error occurred in disagg prefill proxy server - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -158,36 +169,39 @@ async def handle_chat_completions(request: Request):
req_data = await request.json()
# Send request to prefill service, ignore the response
- await send_request_to_service(app.state.prefill_client,
- "/chat/completions", req_data)
+ await send_request_to_service(
+ app.state.prefill_client, "/chat/completions", req_data
+ )
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
- async for chunk in stream_service_response(app.state.decode_client,
- "/chat/completions",
- req_data):
+ async for chunk in stream_service_response(
+ app.state.decode_client, "/chat/completions", req_data
+ ):
yield chunk
- return StreamingResponse(generate_stream(),
- media_type="application/json")
+ return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import sys
import traceback
+
exc_info = sys.exc_info()
- print("Error occurred in disagg prefill proxy server "
- " - chat completions endpoint")
+ print(
+ "Error occurred in disagg prefill proxy server - chat completions endpoint"
+ )
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
-if __name__ == '__main__':
+if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn
+
uvicorn.run(app, host=global_args.host, port=global_args.port)
diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/lmcache/kv_cache_sharing_lmcache_v1.py
index 7748f8ca6133..0a81b8d3b5b8 100644
--- a/examples/lmcache/kv_cache_sharing_lmcache_v1.py
+++ b/examples/lmcache/kv_cache_sharing_lmcache_v1.py
@@ -3,13 +3,14 @@
This file demonstrates the example usage of remote KV cache sharing
with LMCache.
We will launch 2 vllm instances, and launch an additional LMCache server.
-KV cache is transferred in the following manner:
+KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).
Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
+
import os
import subprocess
import time
@@ -17,7 +18,6 @@
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
-
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
@@ -49,15 +49,16 @@ def run_store(store_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
- ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
- kv_role="kv_both")
+ ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
- llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
- kv_transfer_config=ktc,
- max_model_len=8000,
- gpu_memory_utilization=0.8,
- enforce_eager=True)
+ llm = LLM(
+ model="mistralai/Mistral-7B-Instruct-v0.2",
+ kv_transfer_config=ktc,
+ max_model_len=8000,
+ gpu_memory_utilization=0.8,
+ enforce_eager=True,
+ )
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
@@ -76,15 +77,16 @@ def run_retrieve(store_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
- ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
- kv_role="kv_both")
+ ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
- llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
- kv_transfer_config=ktc,
- max_model_len=8000,
- gpu_memory_utilization=0.8,
- enforce_eager=True)
+ llm = LLM(
+ model="mistralai/Mistral-7B-Instruct-v0.2",
+ kv_transfer_config=ktc,
+ max_model_len=8000,
+ gpu_memory_utilization=0.8,
+ enforce_eager=True,
+ )
print("Waiting for KV cache store to finish...")
store_done.wait()
@@ -100,10 +102,9 @@ def run_retrieve(store_done, prompts, timeout=1):
def run_lmcache_server(port):
- server_proc = subprocess.Popen([
- "python", "-m", "lmcache.experimental.server", "localhost",
- str(port)
- ])
+ server_proc = subprocess.Popen(
+ ["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
+ )
return server_proc
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index bab41c915c32..56cdd6861baa 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
"""
-This example shows how to use vLLM for running offline inference
+This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
+
import os
from dataclasses import asdict
from typing import NamedTuple, Optional
@@ -22,7 +23,7 @@
question_per_audio_count = {
0: "What is 1+1?",
1: "What is recited in the audio?",
- 2: "What sport and what nursery rhyme are referenced?"
+ 2: "What sport and what nursery rhyme are referenced?",
}
@@ -72,8 +73,7 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
- tokenizer = AutoTokenizer.from_pretrained(model_name,
- trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
@@ -82,19 +82,18 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
limit_mm_per_prompt={"audio": audio_count},
)
- stop_tokens = ['<|im_end|>', '<|endoftext|>']
+ stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
audio_placeholder = "()" * audio_count
audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
- messages = [{
- 'role': 'user',
- 'content': f'{audio_placeholder}\n{question}'
- }]
- prompt = tokenizer.apply_chat_template(messages,
- tokenize=False,
- add_generation_prompt=True,
- chat_template=audio_chat_template)
+ messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
+ prompt = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ chat_template=audio_chat_template,
+ )
return ModelRequestData(
engine_args=engine_args,
@@ -113,7 +112,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
speech_lora_path = os.path.join(model_path, "speech-lora")
- placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
+ placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
@@ -145,15 +144,19 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
limit_mm_per_prompt={"audio": audio_count},
)
- audio_in_prompt = "".join([
- f"Audio {idx+1}: "
- f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
- ])
+ audio_in_prompt = "".join(
+ [
+ f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
+ for idx in range(audio_count)
+ ]
+ )
- prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
- "<|im_start|>user\n"
- f"{audio_in_prompt}{question}<|im_end|>\n"
- "<|im_start|>assistant\n")
+ prompt = (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>user\n"
+ f"{audio_in_prompt}{question}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
return ModelRequestData(
engine_args=engine_args,
@@ -172,19 +175,22 @@ def run_qwen2_5_omni(question: str, audio_count: int):
limit_mm_per_prompt={"audio": audio_count},
)
- audio_in_prompt = "".join([
- "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
- ])
+ audio_in_prompt = "".join(
+ ["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)]
+ )
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
- "generating text and speech.")
+ "generating text and speech."
+ )
- prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
- "<|im_start|>user\n"
- f"{audio_in_prompt}{question}<|im_end|>\n"
- "<|im_start|>assistant\n")
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n"
+ f"{audio_in_prompt}{question}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
@@ -196,13 +202,10 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
- messages = [{
- 'role': 'user',
- 'content': "<|audio|>\n" * audio_count + question
- }]
- prompt = tokenizer.apply_chat_template(messages,
- tokenize=False,
- add_generation_prompt=True)
+ messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
+ prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
engine_args = EngineArgs(
model=model_name,
@@ -220,8 +223,7 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
- assert audio_count == 1, (
- "Whisper only support single audio input per prompt")
+ assert audio_count == 1, "Whisper only support single audio input per prompt"
model_name = "openai/whisper-large-v3-turbo"
prompt = "<|startoftranscript|>"
@@ -252,27 +254,33 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
def parse_args():
parser = FlexibleArgumentParser(
- description='Demo on using vLLM for offline inference with '
- 'audio language models')
- parser.add_argument('--model-type',
- '-m',
- type=str,
- default="ultravox",
- choices=model_example_map.keys(),
- help='Huggingface "model_type".')
- parser.add_argument('--num-prompts',
- type=int,
- default=1,
- help='Number of prompts to run.')
- parser.add_argument("--num-audios",
- type=int,
- default=1,
- choices=[0, 1, 2],
- help="Number of audio items per prompt.")
- parser.add_argument("--seed",
- type=int,
- default=None,
- help="Set the seed when initializing `vllm.LLM`.")
+ description="Demo on using vLLM for offline inference with "
+ "audio language models"
+ )
+ parser.add_argument(
+ "--model-type",
+ "-m",
+ type=str,
+ default="ultravox",
+ choices=model_example_map.keys(),
+ help='Huggingface "model_type".',
+ )
+ parser.add_argument(
+ "--num-prompts", type=int, default=1, help="Number of prompts to run."
+ )
+ parser.add_argument(
+ "--num-audios",
+ type=int,
+ default=1,
+ choices=[0, 1, 2],
+ help="Number of audio items per prompt.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Set the seed when initializing `vllm.LLM`.",
+ )
return parser.parse_args()
@@ -283,29 +291,30 @@ def main(args):
raise ValueError(f"Model type {model} is not supported.")
audio_count = args.num_audios
- req_data = model_example_map[model](question_per_audio_count[audio_count],
- audio_count)
+ req_data = model_example_map[model](
+ question_per_audio_count[audio_count], audio_count
+ )
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
- req_data.engine_args.limit_mm_per_prompt or {})
+ req_data.engine_args.limit_mm_per_prompt or {}
+ )
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
- sampling_params = SamplingParams(temperature=0.2,
- max_tokens=64,
- stop_token_ids=req_data.stop_token_ids)
+ sampling_params = SamplingParams(
+ temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
+ )
mm_data = {}
if audio_count > 0:
mm_data = {
"audio": [
- asset.audio_and_sample_rate
- for asset in audio_assets[:audio_count]
+ asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
]
}
@@ -315,8 +324,9 @@ def main(args):
# Batch inference
inputs = [inputs] * args.num_prompts
# Add LoRA request if applicable
- lora_request = (req_data.lora_requests *
- args.num_prompts if req_data.lora_requests else None)
+ lora_request = (
+ req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
+ )
outputs = llm.generate(
inputs,
diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py
index 8e6f78ed7de2..b0bb5aa71b8a 100644
--- a/examples/offline_inference/basic/chat.py
+++ b/examples/offline_inference/basic/chat.py
@@ -56,22 +56,12 @@ def print_outputs(outputs):
# In this script, we demonstrate how to pass input to the chat method:
conversation = [
- {
- "role": "system",
- "content": "You are a helpful assistant"
- },
- {
- "role": "user",
- "content": "Hello"
- },
- {
- "role": "assistant",
- "content": "Hello! How can I assist you today?"
- },
+ {"role": "system", "content": "You are a helpful assistant"},
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hello! How can I assist you today?"},
{
"role": "user",
- "content":
- "Write an essay about the importance of higher education.",
+ "content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
diff --git a/examples/offline_inference/basic/classify.py b/examples/offline_inference/basic/classify.py
index 5b6dcb41eee1..40ccb1294e42 100644
--- a/examples/offline_inference/basic/classify.py
+++ b/examples/offline_inference/basic/classify.py
@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
- parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
- task="classify",
- enforce_eager=True)
+ parser.set_defaults(
+ model="jason9693/Qwen2.5-1.5B-apeach", task="classify", enforce_eager=True
+ )
return parser.parse_args()
@@ -36,10 +36,11 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
probs = output.outputs.probs
- probs_trimmed = ((str(probs[:16])[:-1] +
- ", ...]") if len(probs) > 16 else probs)
- print(f"Prompt: {prompt!r} \n"
- f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
+ probs_trimmed = (str(probs[:16])[:-1] + ", ...]") if len(probs) > 16 else probs
+ print(
+ f"Prompt: {prompt!r} \n"
+ f"Class Probabilities: {probs_trimmed} (size={len(probs)})"
+ )
print("-" * 60)
diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py
index cb5f923ffb69..38a73ccca251 100644
--- a/examples/offline_inference/basic/embed.py
+++ b/examples/offline_inference/basic/embed.py
@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
- parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
- task="embed",
- enforce_eager=True)
+ parser.set_defaults(
+ model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
+ )
return parser.parse_args()
@@ -36,10 +36,10 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
- embeds_trimmed = ((str(embeds[:16])[:-1] +
- ", ...]") if len(embeds) > 16 else embeds)
- print(f"Prompt: {prompt!r} \n"
- f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
+ embeds_trimmed = (
+ (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
+ )
+ print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
print("-" * 60)
diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py
index d2bda8b3180c..3da73c6c407d 100644
--- a/examples/offline_inference/basic/score.py
+++ b/examples/offline_inference/basic/score.py
@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
- parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
- task="score",
- enforce_eager=True)
+ parser.set_defaults(
+ model="BAAI/bge-reranker-v2-m3", task="score", enforce_eager=True
+ )
return parser.parse_args()
diff --git a/examples/offline_inference/batch_llm_inference.py b/examples/offline_inference/batch_llm_inference.py
index 6548857b6d11..c1edfb52ff70 100644
--- a/examples/offline_inference/batch_llm_inference.py
+++ b/examples/offline_inference/batch_llm_inference.py
@@ -17,12 +17,14 @@
Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
+
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
-assert Version(ray.__version__) >= Version(
- "2.44.1"), "Ray version must be at least 2.44.1"
+assert Version(ray.__version__) >= Version("2.44.1"), (
+ "Ray version must be at least 2.44.1"
+)
# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
@@ -53,20 +55,18 @@
vllm_processor = build_llm_processor(
config,
preprocess=lambda row: dict(
- messages=[{
- "role": "system",
- "content": "You are a bot that responds with haikus."
- }, {
- "role": "user",
- "content": row["text"]
- }],
+ messages=[
+ {"role": "system", "content": "You are a bot that responds with haikus."},
+ {"role": "user", "content": row["text"]},
+ ],
sampling_params=dict(
temperature=0.3,
max_tokens=250,
- )),
+ ),
+ ),
postprocess=lambda row: dict(
answer=row["generated_text"],
- **row # This will return all the original columns in the dataset.
+ **row, # This will return all the original columns in the dataset.
),
)
diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py
index b532bf42adfb..61230d895584 100644
--- a/examples/offline_inference/chat_with_tools.py
+++ b/examples/offline_inference/chat_with_tools.py
@@ -50,87 +50,93 @@
# or any other mistral model with function calling ability
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
-llm = LLM(model=model_name,
- tokenizer_mode="mistral",
- config_format="mistral",
- load_format="mistral")
+llm = LLM(
+ model=model_name,
+ tokenizer_mode="mistral",
+ config_format="mistral",
+ load_format="mistral",
+)
def generate_random_id(length=9):
characters = string.ascii_letters + string.digits
- random_id = ''.join(random.choice(characters) for _ in range(length))
+ random_id = "".join(random.choice(characters) for _ in range(length))
return random_id
# simulate an API that can be called
-def get_current_weather(city: str, state: str, unit: 'str'):
- return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
- "partly cloudly, with highs in the 90's.")
+def get_current_weather(city: str, state: str, unit: "str"):
+ return (
+ f"The weather in {city}, {state} is 85 degrees {unit}. It is "
+ "partly cloudly, with highs in the 90's."
+ )
tool_functions = {"get_current_weather": get_current_weather}
-tools = [{
- "type": "function",
- "function": {
- "name": "get_current_weather",
- "description": "Get the current weather in a given location",
- "parameters": {
- "type": "object",
- "properties": {
- "city": {
- "type":
- "string",
- "description":
- "The city to find the weather for, e.g. 'San Francisco'"
+tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "The city to find the weather for, e.g. 'San Francisco'",
+ },
+ "state": {
+ "type": "string",
+ "description": "the two-letter abbreviation for the state that the city is"
+ " in, e.g. 'CA' which would mean 'California'",
+ },
+ "unit": {
+ "type": "string",
+ "description": "The unit to fetch the temperature in",
+ "enum": ["celsius", "fahrenheit"],
+ },
},
- "state": {
- "type":
- "string",
- "description":
- "the two-letter abbreviation for the state that the city is"
- " in, e.g. 'CA' which would mean 'California'"
- },
- "unit": {
- "type": "string",
- "description": "The unit to fetch the temperature in",
- "enum": ["celsius", "fahrenheit"]
- }
+ "required": ["city", "state", "unit"],
},
- "required": ["city", "state", "unit"]
- }
+ },
}
-}]
+]
-messages = [{
- "role":
- "user",
- "content":
- "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
-}]
+messages = [
+ {
+ "role": "user",
+ "content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?",
+ }
+]
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip()
# append the assistant message
-messages.append({
- "role": "assistant",
- "content": output,
-})
+messages.append(
+ {
+ "role": "assistant",
+ "content": output,
+ }
+)
# let's now actually parse and execute the model's output simulating an API call by using the
# above defined function
tool_calls = json.loads(output)
tool_answers = [
- tool_functions[call['name']](**call['arguments']) for call in tool_calls
+ tool_functions[call["name"]](**call["arguments"]) for call in tool_calls
]
# append the answer as a tool message and let the LLM give you an answer
-messages.append({
- "role": "tool",
- "content": "\n\n".join(tool_answers),
- "tool_call_id": generate_random_id(),
-})
+messages.append(
+ {
+ "role": "tool",
+ "content": "\n\n".join(tool_answers),
+ "tool_call_id": generate_random_id(),
+ }
+)
outputs = llm.chat(messages, sampling_params, tools=tools)
diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
index f636a08c0b09..bf60d883c410 100644
--- a/examples/offline_inference/data_parallel.py
+++ b/examples/offline_inference/data_parallel.py
@@ -27,6 +27,7 @@
--master-addr=10.99.48.128 \
--master-port=13345
"""
+
import os
from time import sleep
@@ -36,46 +37,46 @@
def parse_args():
import argparse
+
parser = argparse.ArgumentParser(description="Data Parallel Inference")
- parser.add_argument("--model",
- type=str,
- default="ibm-research/PowerMoE-3b",
- help="Model name or path")
- parser.add_argument("--dp-size",
- type=int,
- default=2,
- help="Data parallel size")
- parser.add_argument("--tp-size",
- type=int,
- default=2,
- help="Tensor parallel size")
- parser.add_argument("--node-size",
- type=int,
- default=1,
- help="Total number of nodes")
- parser.add_argument("--node-rank",
- type=int,
- default=0,
- help="Rank of the current node")
- parser.add_argument("--master-addr",
- type=str,
- default="",
- help="Master node IP address")
- parser.add_argument("--master-port",
- type=int,
- default=0,
- help="Master node port")
- parser.add_argument("--enforce-eager",
- action='store_true',
- help="Enforce eager mode execution.")
- parser.add_argument("--trust-remote-code",
- action='store_true',
- help="Trust remote code.")
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="ibm-research/PowerMoE-3b",
+ help="Model name or path",
+ )
+ parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
+ parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
+ parser.add_argument(
+ "--node-size", type=int, default=1, help="Total number of nodes"
+ )
+ parser.add_argument(
+ "--node-rank", type=int, default=0, help="Rank of the current node"
+ )
+ parser.add_argument(
+ "--master-addr", type=str, default="", help="Master node IP address"
+ )
+ parser.add_argument("--master-port", type=int, default=0, help="Master node port")
+ parser.add_argument(
+ "--enforce-eager", action="store_true", help="Enforce eager mode execution."
+ )
+ parser.add_argument(
+ "--trust-remote-code", action="store_true", help="Trust remote code."
+ )
return parser.parse_args()
-def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
- dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
+def main(
+ model,
+ dp_size,
+ local_dp_rank,
+ global_dp_rank,
+ dp_master_ip,
+ dp_master_port,
+ GPUs_per_dp_rank,
+ enforce_eager,
+ trust_remote_code,
+):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -110,9 +111,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
# since we are doing data parallel, every rank can have different
# sampling params. here we set different max_tokens for different
# ranks for demonstration.
- sampling_params = SamplingParams(temperature=0.8,
- top_p=0.95,
- max_tokens=[16, 20][global_dp_rank % 2])
+ sampling_params = SamplingParams(
+ temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
+ )
# Create an LLM.
llm = LLM(
@@ -130,15 +131,16 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
break
prompt = output.prompt
generated_text = output.outputs[0].text
- print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
- f"Generated text: {generated_text!r}")
+ print(
+ f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
+ f"Generated text: {generated_text!r}"
+ )
# Give engines time to pause their processing loops before exiting.
sleep(1)
if __name__ == "__main__":
-
args = parse_args()
dp_size = args.dp_size
@@ -160,20 +162,29 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
procs = []
for local_dp_rank, global_dp_rank in enumerate(
- range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
- proc = Process(target=main,
- args=(args.model, dp_size, local_dp_rank,
- global_dp_rank, dp_master_ip, dp_master_port,
- tp_size, args.enforce_eager,
- args.trust_remote_code))
+ range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
+ ):
+ proc = Process(
+ target=main,
+ args=(
+ args.model,
+ dp_size,
+ local_dp_rank,
+ global_dp_rank,
+ dp_master_ip,
+ dp_master_port,
+ tp_size,
+ args.enforce_eager,
+ args.trust_remote_code,
+ ),
+ )
proc.start()
procs.append(proc)
exit_code = 0
for proc in procs:
proc.join(timeout=300)
if proc.exitcode is None:
- print(f"Killing process {proc.pid} that "
- f"didn't stop within 5 minutes.")
+ print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:
diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py
index 11918f72feec..13e7759f9953 100644
--- a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py
+++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py
@@ -16,17 +16,18 @@
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
- enforce_eager=True,
- gpu_memory_utilization=0.8,
- max_num_batched_tokens=64,
- max_num_seqs=16,
- kv_transfer_config=KVTransferConfig(
- kv_connector="SharedStorageConnector",
- kv_role="kv_both",
- kv_connector_extra_config={
- "shared_storage_path": "local_storage"
- })) #, max_model_len=2048, max_num_batched_tokens=2048)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B-Instruct",
+ enforce_eager=True,
+ gpu_memory_utilization=0.8,
+ max_num_batched_tokens=64,
+ max_num_seqs=16,
+ kv_transfer_config=KVTransferConfig(
+ kv_connector="SharedStorageConnector",
+ kv_role="kv_both",
+ kv_connector_extra_config={"shared_storage_path": "local_storage"},
+ ),
+) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)
diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
index 798128301e0f..603e67289840 100644
--- a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
+++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
@@ -14,15 +14,16 @@
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
- enforce_eager=True,
- gpu_memory_utilization=0.8,
- kv_transfer_config=KVTransferConfig(
- kv_connector="SharedStorageConnector",
- kv_role="kv_both",
- kv_connector_extra_config={
- "shared_storage_path": "local_storage"
- })) #, max_model_len=2048, max_num_batched_tokens=2048)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B-Instruct",
+ enforce_eager=True,
+ gpu_memory_utilization=0.8,
+ kv_transfer_config=KVTransferConfig(
+ kv_connector="SharedStorageConnector",
+ kv_role="kv_both",
+ kv_connector_extra_config={"shared_storage_path": "local_storage"},
+ ),
+) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(
diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py
index bb6fdd48f79e..3ccab0dcd6d3 100644
--- a/examples/offline_inference/disaggregated_prefill.py
+++ b/examples/offline_inference/disaggregated_prefill.py
@@ -4,6 +4,7 @@
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and then transfer the KV cache between them.
"""
+
import os
import time
from multiprocessing import Event, Process
@@ -32,17 +33,21 @@ def run_prefill(prefill_done):
# This instance is the prefill node (kv_producer, rank 0).
# The number of parallel instances for KV cache transfer is set to 2,
# as required for PyNcclConnector.
- ktc = KVTransferConfig(kv_connector="PyNcclConnector",
- kv_role="kv_producer",
- kv_rank=0,
- kv_parallel_size=2)
+ ktc = KVTransferConfig(
+ kv_connector="PyNcclConnector",
+ kv_role="kv_producer",
+ kv_rank=0,
+ kv_parallel_size=2,
+ )
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU.
- llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
- kv_transfer_config=ktc,
- max_model_len=2000,
- gpu_memory_utilization=0.8)
+ llm = LLM(
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ kv_transfer_config=ktc,
+ max_model_len=2000,
+ gpu_memory_utilization=0.8,
+ )
llm.generate(prompts, sampling_params)
print("Prefill node is finished.")
@@ -72,17 +77,21 @@ def run_decode(prefill_done):
# This instance is the decode node (kv_consumer, rank 1).
# The number of parallel instances for KV cache transfer is set to 2,
# as required for PyNcclConnector.
- ktc = KVTransferConfig(kv_connector="PyNcclConnector",
- kv_role="kv_consumer",
- kv_rank=1,
- kv_parallel_size=2)
+ ktc = KVTransferConfig(
+ kv_connector="PyNcclConnector",
+ kv_role="kv_consumer",
+ kv_rank=1,
+ kv_parallel_size=2,
+ )
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU.
- llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
- kv_transfer_config=ktc,
- max_model_len=2000,
- gpu_memory_utilization=0.8)
+ llm = LLM(
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ kv_transfer_config=ktc,
+ max_model_len=2000,
+ gpu_memory_utilization=0.8,
+ )
# Wait for the producer to start the pipe
print("Waiting for prefill node to finish...")
@@ -99,8 +108,8 @@ def run_decode(prefill_done):
def main():
prefill_done = Event()
- prefill_process = Process(target=run_prefill, args=(prefill_done, ))
- decode_process = Process(target=run_decode, args=(prefill_done, ))
+ prefill_process = Process(target=run_prefill, args=(prefill_done,))
+ decode_process = Process(target=run_decode, args=(prefill_done,))
# Start prefill node
prefill_process.start()
diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index 615f67e9f8d8..3dd9e5464641 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -20,9 +20,7 @@ def load_prompts(dataset_path, num_prompts):
print(f"Error reading dataset: {e}")
return []
else:
- prompts = [
- "The future of AI is", "The president of the United States is"
- ]
+ prompts = ["The future of AI is", "The president of the United States is"]
return prompts[:num_prompts]
@@ -33,34 +31,32 @@ def parse_args():
"--dataset",
type=str,
default="./examples/data/gsm8k.jsonl",
- help="downloaded from the eagle repo " \
- "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+ help="downloaded from the eagle repo "
+ "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
+ )
+ parser.add_argument(
+ "--method", type=str, default="eagle", choices=["eagle", "eagle3"]
)
- parser.add_argument("--method",
- type=str,
- default='eagle',
- choices=['eagle', 'eagle3'])
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
- parser.add_argument("--enforce_eager", action='store_true')
- parser.add_argument("--enable_chunked_prefill", action='store_true')
+ parser.add_argument("--enforce_eager", action="store_true")
+ parser.add_argument("--enable_chunked_prefill", action="store_true")
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
return parser.parse_args()
def main():
-
args = parse_args()
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
- if args.method == 'eagle':
+ if args.method == "eagle":
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
- elif args.method == 'eagle3':
+ elif args.method == "eagle3":
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
else:
raise ValueError(f"unknown method: {args.method}")
@@ -72,11 +68,9 @@ def main():
prompts = load_prompts(args.dataset, args.num_prompts)
prompt_ids = [
- tokenizer.apply_chat_template([{
- "role": "user",
- "content": prompt
- }],
- add_generation_prompt=True)
+ tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}], add_generation_prompt=True
+ )
for prompt in prompts
]
@@ -102,8 +96,7 @@ def main():
sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
- outputs = llm.generate(prompt_token_ids=prompt_ids,
- sampling_params=sampling_params)
+ outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
# print the generated text
for output in outputs:
@@ -120,19 +113,22 @@ def main():
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
- for step, count in enumerate(
- output.metrics.spec_token_acceptance_counts):
+ for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count
print("-" * 50)
- print(f"mean acceptance length (including bonus tokens): \
- {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}")
+ print(
+ f"mean acceptance length (including bonus tokens): \
+ {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}"
+ )
print("-" * 50)
# print acceptance at each token position
for i in range(len(acceptance_counts)):
- print(f"acceptance at token {i}:"
- f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}")
+ print(
+ f"acceptance at token {i}:"
+ f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}"
+ )
if __name__ == "__main__":
diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/embed_jina_embeddings_v3.py
index b347ddbf3197..23f60c431fc2 100644
--- a/examples/offline_inference/embed_jina_embeddings_v3.py
+++ b/examples/offline_inference/embed_jina_embeddings_v3.py
@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
- parser.set_defaults(model="jinaai/jina-embeddings-v3",
- task="embed",
- trust_remote_code=True)
+ parser.set_defaults(
+ model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
+ )
return parser.parse_args()
@@ -41,11 +41,14 @@ def main(args: Namespace):
print("-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
- embeds_trimmed = ((str(embeds[:16])[:-1] +
- ", ...]") if len(embeds) > 16 else embeds)
- print(f"Prompt: {prompt!r} \n"
- f"Embeddings for text matching: {embeds_trimmed} "
- f"(size={len(embeds)})")
+ embeds_trimmed = (
+ (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
+ )
+ print(
+ f"Prompt: {prompt!r} \n"
+ f"Embeddings for text matching: {embeds_trimmed} "
+ f"(size={len(embeds)})"
+ )
print("-" * 60)
diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/embed_matryoshka_fy.py
index 7a6cb02556d9..59c0592ae9e2 100644
--- a/examples/offline_inference/embed_matryoshka_fy.py
+++ b/examples/offline_inference/embed_matryoshka_fy.py
@@ -10,9 +10,9 @@ def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
- parser.set_defaults(model="jinaai/jina-embeddings-v3",
- task="embed",
- trust_remote_code=True)
+ parser.set_defaults(
+ model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
+ )
return parser.parse_args()
@@ -39,11 +39,10 @@ def main(args: Namespace):
print("-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
- embeds_trimmed = ((str(embeds[:16])[:-1] +
- ", ...]") if len(embeds) > 16 else embeds)
- print(f"Prompt: {prompt!r} \n"
- f"Embeddings: {embeds_trimmed} "
- f"(size={len(embeds)})")
+ embeds_trimmed = (
+ (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
+ )
+ print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
print("-" * 60)
diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py
index c4916e00f473..83dd1f667eb5 100644
--- a/examples/offline_inference/encoder_decoder.py
+++ b/examples/offline_inference/encoder_decoder.py
@@ -1,12 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
-'''
+"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
-'''
+"""
from vllm import LLM, SamplingParams
-from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
- TokensPrompt, zip_enc_dec_prompts)
+from vllm.inputs import (
+ ExplicitEncoderDecoderPrompt,
+ TextPrompt,
+ TokensPrompt,
+ zip_enc_dec_prompts,
+)
def create_prompts(tokenizer):
@@ -18,8 +22,9 @@ def create_prompts(tokenizer):
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
- tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
- prompt="The capital of France is"))
+ tokens_prompt = TokensPrompt(
+ prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
+ )
# - Pass a single prompt to encoder/decoder model
# (implicitly encoder input prompt);
# decoder input prompt is assumed to be None
@@ -57,14 +62,19 @@ def create_prompts(tokenizer):
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list = zip_enc_dec_prompts(
- ['An encoder prompt', 'Another encoder prompt'],
- ['A decoder prompt', 'Another decoder prompt'])
+ ["An encoder prompt", "Another encoder prompt"],
+ ["A decoder prompt", "Another decoder prompt"],
+ )
# - Let's put all of the above example prompts together into one list
# which we will pass to the encoder/decoder LLM.
return [
- single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
- enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
+ single_text_prompt_raw,
+ single_text_prompt,
+ single_tokens_prompt,
+ enc_dec_prompt1,
+ enc_dec_prompt2,
+ enc_dec_prompt3,
] + zipped_prompt_list
@@ -85,10 +95,12 @@ def print_outputs(outputs):
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
- print(f"Output {i+1}:")
- print(f"Encoder prompt: {encoder_prompt!r}\n"
- f"Decoder prompt: {prompt!r}\n"
- f"Generated text: {generated_text!r}")
+ print(f"Output {i + 1}:")
+ print(
+ f"Encoder prompt: {encoder_prompt!r}\n"
+ f"Decoder prompt: {prompt!r}\n"
+ f"Generated text: {generated_text!r}"
+ )
print("-" * 50)
diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py
index 2883c37ca236..ae3737e37594 100644
--- a/examples/offline_inference/encoder_decoder_multimodal.py
+++ b/examples/offline_inference/encoder_decoder_multimodal.py
@@ -3,6 +3,7 @@
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
+
import time
from collections.abc import Sequence
from dataclasses import asdict
@@ -30,18 +31,14 @@ def run_florence2():
)
prompts = [
- { # implicit prompt with task token
+ { # implicit prompt with task token
"prompt": "",
- "multi_modal_data": {
- "image": ImageAsset("stop_sign").pil_image
- },
+ "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
},
- { # explicit encoder/decoder prompt
+ { # explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "Describe in detail what is shown in the image.",
- "multi_modal_data": {
- "image": ImageAsset("cherry_blossom").pil_image
- },
+ "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
},
"decoder_prompt": "",
},
@@ -63,20 +60,20 @@ def run_mllama():
)
prompts = [
- { # Implicit prompt
- "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501
+ { # Implicit prompt
+ "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501
"multi_modal_data": {
"image": ImageAsset("stop_sign").pil_image,
},
},
- { # Explicit prompt
+ { # Explicit prompt
"encoder_prompt": {
"prompt": "<|image|>",
"multi_modal_data": {
"image": ImageAsset("stop_sign").pil_image,
},
},
- "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501
+ "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501
},
]
@@ -96,13 +93,13 @@ def run_whisper():
)
prompts = [
- { # Test implicit prompt
+ { # Test implicit prompt
"prompt": "<|startoftranscript|>",
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
},
},
- { # Test explicit encoder/decoder prompt
+ { # Test explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
@@ -110,7 +107,7 @@ def run_whisper():
},
},
"decoder_prompt": "<|startoftranscript|>",
- }
+ },
]
return ModelRequestData(
@@ -128,18 +125,23 @@ def run_whisper():
def parse_args():
parser = FlexibleArgumentParser(
- description='Demo on using vLLM for offline inference with '
- 'vision language models for text generation')
- parser.add_argument('--model-type',
- '-m',
- type=str,
- default="mllama",
- choices=model_example_map.keys(),
- help='Huggingface "model_type".')
- parser.add_argument("--seed",
- type=int,
- default=None,
- help="Set the seed when initializing `vllm.LLM`.")
+ description="Demo on using vLLM for offline inference with "
+ "vision language models for text generation"
+ )
+ parser.add_argument(
+ "--model-type",
+ "-m",
+ type=str,
+ default="mllama",
+ choices=model_example_map.keys(),
+ help='Huggingface "model_type".',
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Set the seed when initializing `vllm.LLM`.",
+ )
return parser.parse_args()
@@ -153,7 +155,8 @@ def main(args):
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
- req_data.engine_args.limit_mm_per_prompt or {})
+ req_data.engine_args.limit_mm_per_prompt or {}
+ )
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
@@ -179,8 +182,7 @@ def main(args):
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
- print(f"Decoder prompt: {prompt!r}, "
- f"Generated text: {generated_text!r}")
+ print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")
duration = time.time() - start
diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py
index d84cd9ee9f52..5d5e55a83d22 100644
--- a/examples/offline_inference/llm_engine_example.py
+++ b/examples/offline_inference/llm_engine_example.py
@@ -3,6 +3,7 @@
This file demonstrates using the `LLMEngine`
for processing prompts with various sampling parameters.
"""
+
import argparse
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
@@ -12,24 +13,26 @@
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
"""Create a list of test prompts with their sampling parameters."""
return [
- ("A robot may not injure a human being",
- SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
- ("To be or not to be,",
- SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
- ("What is the meaning of life?",
- SamplingParams(n=2,
- temperature=0.8,
- top_p=0.95,
- frequency_penalty=0.1)),
+ (
+ "A robot may not injure a human being",
+ SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
+ ),
+ (
+ "To be or not to be,",
+ SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
+ ),
+ (
+ "What is the meaning of life?",
+ SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
+ ),
]
-def process_requests(engine: LLMEngine,
- test_prompts: list[tuple[str, SamplingParams]]):
+def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
- print('-' * 50)
+ print("-" * 50)
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
@@ -41,7 +44,7 @@ def process_requests(engine: LLMEngine,
for request_output in request_outputs:
if request_output.finished:
print(request_output)
- print('-' * 50)
+ print("-" * 50)
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
@@ -52,7 +55,8 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine:
def parse_args():
parser = FlexibleArgumentParser(
- description='Demo on using the LLMEngine class directly')
+ description="Demo on using the LLMEngine class directly"
+ )
parser = EngineArgs.add_cli_args(parser)
return parser.parse_args()
@@ -64,6 +68,6 @@ def main(args: argparse.Namespace):
process_requests(engine, test_prompts)
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
main(args)
diff --git a/examples/offline_inference/load_sharded_state.py b/examples/offline_inference/load_sharded_state.py
index 7e90d5d25e29..5bb2327a3f83 100644
--- a/examples/offline_inference/load_sharded_state.py
+++ b/examples/offline_inference/load_sharded_state.py
@@ -36,22 +36,21 @@ def parse_args():
parser.set_defaults(load_format="sharded_state")
# Add validation arguments
- parser.add_argument("--prompt",
- type=str,
- default="Hello, world!",
- help="Prompt for validation")
- parser.add_argument("--max-tokens",
- type=int,
- default=100,
- help="Maximum number of tokens to generate")
- parser.add_argument("--temperature",
- type=float,
- default=0.7,
- help="Sampling temperature")
- parser.add_argument("--top-p",
- type=float,
- default=1.0,
- help="Top-p sampling parameter")
+ parser.add_argument(
+ "--prompt", type=str, default="Hello, world!", help="Prompt for validation"
+ )
+ parser.add_argument(
+ "--max-tokens",
+ type=int,
+ default=100,
+ help="Maximum number of tokens to generate",
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0.7, help="Sampling temperature"
+ )
+ parser.add_argument(
+ "--top-p", type=float, default=1.0, help="Top-p sampling parameter"
+ )
return parser.parse_args()
@@ -60,8 +59,9 @@ def main():
args = parse_args()
engine_args = EngineArgs.from_cli_args(args)
- print(f"Loading model from {engine_args.model} "
- f"using format {engine_args.load_format}")
+ print(
+ f"Loading model from {engine_args.model} using format {engine_args.load_format}"
+ )
print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")
# Load the model using engine args
@@ -90,4 +90,4 @@ def main():
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py
index b6608ec6e958..33c660015ba7 100644
--- a/examples/offline_inference/lora_with_quantization_inference.py
+++ b/examples/offline_inference/lora_with_quantization_inference.py
@@ -17,50 +17,55 @@
def create_test_prompts(
- lora_path: str
+ lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
return [
# this is an example of using quantization without LoRA
- ("My name is",
- SamplingParams(temperature=0.0,
- logprobs=1,
- prompt_logprobs=1,
- max_tokens=128), None),
+ (
+ "My name is",
+ SamplingParams(
+ temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
+ ),
+ None,
+ ),
# the next three examples use quantization with LoRA
- ("my name is",
- SamplingParams(temperature=0.0,
- logprobs=1,
- prompt_logprobs=1,
- max_tokens=128),
- LoRARequest("lora-test-1", 1, lora_path)),
- ("The capital of USA is",
- SamplingParams(temperature=0.0,
- logprobs=1,
- prompt_logprobs=1,
- max_tokens=128),
- LoRARequest("lora-test-2", 1, lora_path)),
- ("The capital of France is",
- SamplingParams(temperature=0.0,
- logprobs=1,
- prompt_logprobs=1,
- max_tokens=128),
- LoRARequest("lora-test-3", 1, lora_path)),
+ (
+ "my name is",
+ SamplingParams(
+ temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
+ ),
+ LoRARequest("lora-test-1", 1, lora_path),
+ ),
+ (
+ "The capital of USA is",
+ SamplingParams(
+ temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
+ ),
+ LoRARequest("lora-test-2", 1, lora_path),
+ ),
+ (
+ "The capital of France is",
+ SamplingParams(
+ temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
+ ),
+ LoRARequest("lora-test-3", 1, lora_path),
+ ),
]
-def process_requests(engine: LLMEngine,
- test_prompts: list[tuple[str, SamplingParams,
- Optional[LoRARequest]]]):
+def process_requests(
+ engine: LLMEngine,
+ test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
+):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
- engine.add_request(str(request_id),
- prompt,
- sampling_params,
- lora_request=lora_request)
+ engine.add_request(
+ str(request_id), prompt, sampling_params, lora_request=lora_request
+ )
request_id += 1
request_outputs: list[RequestOutput] = engine.step()
@@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine,
print(f"Output: {request_output.outputs[0].text}")
-def initialize_engine(model: str, quantization: str,
- lora_repo: Optional[str]) -> LLMEngine:
+def initialize_engine(
+ model: str, quantization: str, lora_repo: Optional[str]
+) -> LLMEngine:
"""Initialize the LLMEngine."""
- engine_args = EngineArgs(model=model,
- quantization=quantization,
- enable_lora=True,
- max_lora_rank=64,
- max_loras=4)
+ engine_args = EngineArgs(
+ model=model,
+ quantization=quantization,
+ enable_lora=True,
+ max_lora_rank=64,
+ max_loras=4,
+ )
return LLMEngine.from_engine_args(engine_args)
@@ -90,32 +98,30 @@ def main():
# QLoRA (https://arxiv.org/abs/2305.14314)
{
"name": "qlora_inference_example",
- 'model': "huggyllama/llama-7b",
- 'quantization': "bitsandbytes",
- 'lora_repo': 'timdettmers/qlora-flan-7b'
+ "model": "huggyllama/llama-7b",
+ "quantization": "bitsandbytes",
+ "lora_repo": "timdettmers/qlora-flan-7b",
},
{
"name": "AWQ_inference_with_lora_example",
- 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
- 'quantization': "awq",
- 'lora_repo': 'jashing/tinyllama-colorist-lora'
+ "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
+ "quantization": "awq",
+ "lora_repo": "jashing/tinyllama-colorist-lora",
},
{
"name": "GPTQ_inference_with_lora_example",
- 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
- 'quantization': "gptq",
- 'lora_repo': 'jashing/tinyllama-colorist-lora'
- }
+ "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+ "quantization": "gptq",
+ "lora_repo": "jashing/tinyllama-colorist-lora",
+ },
]
for test_config in test_configs:
- print(
- f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
+ print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
+ engine = initialize_engine(
+ test_config["model"], test_config["quantization"], test_config["lora_repo"]
)
- engine = initialize_engine(test_config['model'],
- test_config['quantization'],
- test_config['lora_repo'])
- lora_path = snapshot_download(repo_id=test_config['lora_repo'])
+ lora_path = snapshot_download(repo_id=test_config["lora_repo"])
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)
@@ -125,5 +131,5 @@ def main():
torch.cuda.empty_cache()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py
index 37c3181dc5fa..98fef2648f6b 100644
--- a/examples/offline_inference/mistral-small.py
+++ b/examples/offline_inference/mistral-small.py
@@ -74,19 +74,10 @@ def run_simple_demo(args: argparse.Namespace):
messages = [
{
- "role":
- "user",
+ "role": "user",
"content": [
- {
- "type": "text",
- "text": prompt
- },
- {
- "type": "image_url",
- "image_url": {
- "url": image_url
- }
- },
+ {"type": "text", "text": prompt},
+ {"type": "image_url", "image_url": {"url": image_url}},
],
},
]
@@ -121,25 +112,11 @@ def run_advanced_demo(args: argparse.Namespace):
messages = [
{
- "role":
- "user",
+ "role": "user",
"content": [
- {
- "type": "text",
- "text": prompt
- },
- {
- "type": "image_url",
- "image_url": {
- "url": url_1
- }
- },
- {
- "type": "image_url",
- "image_url": {
- "url": url_2
- }
- },
+ {"type": "text", "text": prompt},
+ {"type": "image_url", "image_url": {"url": url_1}},
+ {"type": "image_url", "image_url": {"url": url_2}},
],
},
{
@@ -153,12 +130,7 @@ def run_advanced_demo(args: argparse.Namespace):
{
"role": "user",
"content": [
- {
- "type": "image_url",
- "image_url": {
- "url": url_3
- }
- },
+ {"type": "image_url", "image_url": {"url": url_3}},
],
},
]
@@ -171,7 +143,8 @@ def run_advanced_demo(args: argparse.Namespace):
def parse_args():
parser = argparse.ArgumentParser(
- description="Run a demo in simple or advanced mode.")
+ description="Run a demo in simple or advanced mode."
+ )
parser.add_argument(
"mode",
@@ -179,15 +152,18 @@ def parse_args():
help="Specify the demo mode: 'simple' or 'advanced'",
)
- parser.add_argument('--format',
- choices=["mistral", "hf"],
- default="mistral",
- help='Specify the format of the model to load.')
+ parser.add_argument(
+ "--format",
+ choices=["mistral", "hf"],
+ default="mistral",
+ help="Specify the format of the model to load.",
+ )
parser.add_argument(
- '--disable-mm-preprocessor-cache',
- action='store_true',
- help='If True, disables caching of multi-modal preprocessor/mapper.')
+ "--disable-mm-preprocessor-cache",
+ action="store_true",
+ help="If True, disables caching of multi-modal preprocessor/mapper.",
+ )
return parser.parse_args()
diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py
index 53c58a76d9dc..b750397f45b8 100644
--- a/examples/offline_inference/mlpspeculator.py
+++ b/examples/offline_inference/mlpspeculator.py
@@ -13,8 +13,9 @@
from vllm import LLM, SamplingParams
-def time_generation(llm: LLM, prompts: list[str],
- sampling_params: SamplingParams, title: str):
+def time_generation(
+ llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
+):
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
# Warmup first
@@ -25,8 +26,7 @@ def time_generation(llm: LLM, prompts: list[str],
end = time.time()
print("-" * 50)
print(title)
- print("time: ",
- (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
+ print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
# Print the outputs.
for output in outputs:
generated_text = output.outputs[0].text
@@ -38,7 +38,8 @@ def main():
template = (
"Below is an instruction that describes a task. Write a response "
"that appropriately completes the request.\n\n### Instruction:\n{}"
- "\n\n### Response:\n")
+ "\n\n### Response:\n"
+ )
# Sample prompts.
prompts = [
diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py
index de409740292a..1fa2f16f82a8 100644
--- a/examples/offline_inference/multilora_inference.py
+++ b/examples/offline_inference/multilora_inference.py
@@ -15,7 +15,7 @@
def create_test_prompts(
- lora_path: str
+ lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters.
@@ -26,38 +26,49 @@ def create_test_prompts(
first adapter have finished.
"""
return [
- ("A robot may not injure a human being",
- SamplingParams(temperature=0.0,
- logprobs=1,
- prompt_logprobs=1,
- max_tokens=128), None),
- ("To be or not to be,",
- SamplingParams(temperature=0.8,
- top_k=5,
- presence_penalty=0.2,
- max_tokens=128), None),
+ (
+ "A robot may not injure a human being",
+ SamplingParams(
+ temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
+ ),
+ None,
+ ),
+ (
+ "To be or not to be,",
+ SamplingParams(
+ temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
+ ),
+ None,
+ ),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
- SamplingParams(temperature=0.0,
- logprobs=1,
- prompt_logprobs=1,
- max_tokens=128,
- stop_token_ids=[32003]),
- LoRARequest("sql-lora", 1, lora_path)),
+ SamplingParams(
+ temperature=0.0,
+ logprobs=1,
+ prompt_logprobs=1,
+ max_tokens=128,
+ stop_token_ids=[32003],
+ ),
+ LoRARequest("sql-lora", 1, lora_path),
+ ),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
- SamplingParams(temperature=0.0,
- logprobs=1,
- prompt_logprobs=1,
- max_tokens=128,
- stop_token_ids=[32003]),
- LoRARequest("sql-lora2", 2, lora_path)),
+ SamplingParams(
+ temperature=0.0,
+ logprobs=1,
+ prompt_logprobs=1,
+ max_tokens=128,
+ stop_token_ids=[32003],
+ ),
+ LoRARequest("sql-lora2", 2, lora_path),
+ ),
]
-def process_requests(engine: LLMEngine,
- test_prompts: list[tuple[str, SamplingParams,
- Optional[LoRARequest]]]):
+def process_requests(
+ engine: LLMEngine,
+ test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
+):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
@@ -65,10 +76,9 @@ def process_requests(engine: LLMEngine,
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
- engine.add_request(str(request_id),
- prompt,
- sampling_params,
- lora_request=lora_request)
+ engine.add_request(
+ str(request_id), prompt, sampling_params, lora_request=lora_request
+ )
request_id += 1
request_outputs: list[RequestOutput] = engine.step()
@@ -88,12 +98,14 @@ def initialize_engine() -> LLMEngine:
# numbers will cause higher memory usage. If you know that all LoRAs will
# use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache.
- engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
- enable_lora=True,
- max_loras=1,
- max_lora_rank=8,
- max_cpu_loras=2,
- max_num_seqs=256)
+ engine_args = EngineArgs(
+ model="meta-llama/Llama-2-7b-hf",
+ enable_lora=True,
+ max_loras=1,
+ max_lora_rank=8,
+ max_cpu_loras=2,
+ max_num_seqs=256,
+ )
return LLMEngine.from_engine_args(engine_args)
@@ -105,5 +117,5 @@ def main():
process_requests(engine, test_prompts)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/offline_inference/neuron.py b/examples/offline_inference/neuron.py
index 5906c7b2c6b3..f2d7698f22d7 100644
--- a/examples/offline_inference/neuron.py
+++ b/examples/offline_inference/neuron.py
@@ -30,7 +30,8 @@ def main():
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron",
- tensor_parallel_size=2)
+ tensor_parallel_size=2,
+ )
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py
index 4f63f1a2fb3c..a51caa2aec8b 100644
--- a/examples/offline_inference/neuron_eagle.py
+++ b/examples/offline_inference/neuron_eagle.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
-This example shows how to run offline inference with an EAGLE speculative
+This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
@@ -24,7 +24,7 @@
speculative_config={
"model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
"num_speculative_tokens": 5,
- "max_model_len": 2048
+ "max_model_len": 2048,
},
max_num_seqs=4,
# The max_model_len and block_size arguments are required to be same as
@@ -40,7 +40,7 @@
tensor_parallel_size=32,
override_neuron_config={
"enable_eagle_speculation": True,
- "enable_fused_speculation": True
+ "enable_fused_speculation": True,
},
)
diff --git a/examples/offline_inference/neuron_int8_quantization.py b/examples/offline_inference/neuron_int8_quantization.py
index af21274a3a5b..ec38525b9daf 100644
--- a/examples/offline_inference/neuron_int8_quantization.py
+++ b/examples/offline_inference/neuron_int8_quantization.py
@@ -5,12 +5,12 @@
from vllm import LLM, SamplingParams
# creates XLA hlo graphs for all the context length buckets.
-os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
+os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
-os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
+os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
# Quantizes neuron model weight to int8 ,
# The default config for quantization is int8 dtype.
-os.environ['NEURON_QUANT_DTYPE'] = "s8"
+os.environ["NEURON_QUANT_DTYPE"] = "s8"
# Sample prompts.
prompts = [
@@ -44,7 +44,8 @@ def main():
override_neuron_config={
"cast_logits_dtype": "bfloat16",
},
- tensor_parallel_size=2)
+ tensor_parallel_size=2,
+ )
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py
index bef434bae5ba..ecacbab771c2 100644
--- a/examples/offline_inference/neuron_speculation.py
+++ b/examples/offline_inference/neuron_speculation.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
-This example shows how to run offline inference with a speculative
+This example shows how to run offline inference with a speculative
decoding model on neuron.
"""
@@ -19,9 +19,9 @@
def config_buckets():
"""Configure context length and token gen buckets."""
# creates XLA hlo graphs for all the context length buckets.
- os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
+ os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
- os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
+ os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
def initialize_model():
@@ -31,7 +31,7 @@ def initialize_model():
speculative_config={
"model": "openlm-research/open_llama_3b",
"num_speculative_tokens": 4,
- "max_model_len": 2048
+ "max_model_len": 2048,
},
max_num_seqs=4,
max_model_len=2048,
@@ -60,5 +60,5 @@ def main():
process_requests(model, sampling_params)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/offline_inference/prefix_caching.py b/examples/offline_inference/prefix_caching.py
index f0bec387d3a9..d3dad24956a6 100644
--- a/examples/offline_inference/prefix_caching.py
+++ b/examples/offline_inference/prefix_caching.py
@@ -16,7 +16,8 @@
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
- "the following paragraph: ")
+ "the following paragraph: "
+)
# Sample prompts.
prompts = [
@@ -58,9 +59,11 @@ def main():
cleanup_dist_env_and_memory()
# Create an LLM with prefix caching enabled.
- prefix_cached_llm = LLM(model="facebook/opt-125m",
- enable_prefix_caching=True,
- gpu_memory_utilization=0.4)
+ prefix_cached_llm = LLM(
+ model="facebook/opt-125m",
+ enable_prefix_caching=True,
+ gpu_memory_utilization=0.4,
+ )
# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
@@ -81,10 +84,12 @@ def main():
print("-" * 50)
# Compare the results and display the speedup
- generated_same = all([
- regular_generated_texts[i] == cached_generated_texts[i]
- for i in range(len(prompts))
- ])
+ generated_same = all(
+ [
+ regular_generated_texts[i] == cached_generated_texts[i]
+ for i in range(len(prompts))
+ ]
+ )
print(f"Generated answers are the same: {generated_same}")
diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py
index f97a1f32e621..586ec28b7a4c 100644
--- a/examples/offline_inference/prithvi_geospatial_mae.py
+++ b/examples/offline_inference/prithvi_geospatial_mae.py
@@ -16,7 +16,8 @@
Run the example:
python prithvi_geospatial_mae.py
-""" # noqa: E501
+""" # noqa: E501
+
import argparse
import datetime
import os
@@ -110,77 +111,67 @@
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
-with open(os.path.join(os.path.dirname(__file__), "./model/config.json"),
- 'w') as config_file:
+with open(
+ os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
+) as config_file:
config_file.write(model_config)
datamodule_config = {
- 'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'],
- 'batch_size':
- 16,
- 'constant_scale':
- 0.0001,
- 'data_root':
- '/dccstor/geofm-finetuning/datasets/sen1floods11',
- 'drop_last':
- True,
- 'no_data_replace':
- 0.0,
- 'no_label_replace':
- -1,
- 'num_workers':
- 8,
- 'test_transform': [
- albumentations.Resize(always_apply=False,
- height=448,
- interpolation=1,
- p=1,
- width=448),
- albumentations.pytorch.ToTensorV2(transpose_mask=False,
- always_apply=True,
- p=1.0)
+ "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
+ "batch_size": 16,
+ "constant_scale": 0.0001,
+ "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
+ "drop_last": True,
+ "no_data_replace": 0.0,
+ "no_label_replace": -1,
+ "num_workers": 8,
+ "test_transform": [
+ albumentations.Resize(
+ always_apply=False, height=448, interpolation=1, p=1, width=448
+ ),
+ albumentations.pytorch.ToTensorV2(
+ transpose_mask=False, always_apply=True, p=1.0
+ ),
],
}
class PrithviMAE:
-
def __init__(self):
print("Initializing PrithviMAE model")
- self.model = LLM(model=os.path.join(os.path.dirname(__file__),
- "./model"),
- skip_tokenizer_init=True,
- dtype="float32")
+ self.model = LLM(
+ model=os.path.join(os.path.dirname(__file__), "./model"),
+ skip_tokenizer_init=True,
+ dtype="float32",
+ )
def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure
mm_data = {
- "pixel_values":
- torch.empty(0) if input_data is None else input_data,
- "location_coords":
- torch.empty(0) if location_coords is None else location_coords
+ "pixel_values": torch.empty(0) if input_data is None else input_data,
+ "location_coords": torch.empty(0)
+ if location_coords is None
+ else location_coords,
}
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False)
- print(
- "################ Inference done (it took seconds) ##############"
- )
+ print("################ Inference done (it took seconds) ##############")
return outputs[0].outputs.data
def generate_datamodule():
datamodule = Sen1Floods11NonGeoDataModule(
- data_root=datamodule_config['data_root'],
+ data_root=datamodule_config["data_root"],
batch_size=datamodule_config["batch_size"],
num_workers=datamodule_config["num_workers"],
bands=datamodule_config["bands"],
drop_last=datamodule_config["drop_last"],
- test_transform=datamodule_config["test_transform"
- ""])
+ test_transform=datamodule_config["test_transform"],
+ )
return datamodule
@@ -204,8 +195,7 @@ def process_channel_group(orig_img, channels):
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
min_value = OFFSET
- orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0,
- 1)
+ orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
# No data as zeros
orig_img[~valid_mask] = 0
@@ -300,18 +290,21 @@ def load_example(
location_coords.append(coords)
try:
- match = re.search(r'(\d{7,8}T\d{6})', file)
+ match = re.search(r"(\d{7,8}T\d{6})", file)
if match:
year = int(match.group(1)[:4])
- julian_day = match.group(1).split('T')[0][4:]
+ julian_day = match.group(1).split("T")[0][4:]
if len(julian_day) == 3:
julian_day = int(julian_day)
else:
- julian_day = datetime.datetime.strptime(
- julian_day, '%m%d').timetuple().tm_yday
+ julian_day = (
+ datetime.datetime.strptime(julian_day, "%m%d")
+ .timetuple()
+ .tm_yday
+ )
temporal_coords.append([year, julian_day])
except Exception as e:
- print(f'Could not extract timestamp for {file} ({e})')
+ print(f"Could not extract timestamp for {file} ({e})")
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
@@ -320,50 +313,44 @@ def load_example(
return imgs, temporal_coords, location_coords, metas
-def run_model(input_data,
- temporal_coords,
- location_coords,
- model,
- datamodule,
- img_size,
- lightning_model=None):
+def run_model(
+ input_data,
+ temporal_coords,
+ location_coords,
+ model,
+ datamodule,
+ img_size,
+ lightning_model=None,
+):
# Reflect pad if not divisible by img_size
original_h, original_w = input_data.shape[-2:]
pad_h = (img_size - (original_h % img_size)) % img_size
pad_w = (img_size - (original_w % img_size)) % img_size
- input_data = np.pad(input_data,
- ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
- mode="reflect")
+ input_data = np.pad(
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
+ )
# Build sliding window
batch_size = 1
batch = torch.tensor(input_data, device="cpu")
- windows = (batch.unfold(3, img_size,
- img_size).unfold(4, img_size, img_size))
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]
- windows = rearrange(windows,
- "b c t h1 w1 h w -> (b h1 w1) c t h w",
- h=img_size,
- w=img_size)
+ windows = rearrange(
+ windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
+ )
# Split into batches if number of windows > batch_size
- num_batches = windows.shape[0] // batch_size if windows.shape[
- 0] > batch_size else 1
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0)
- if torch.cuda.is_available():
- device = torch.device('cuda')
- else:
- device = torch.device('cpu')
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if temporal_coords:
- temporal_coords = torch.tensor(temporal_coords,
- device=device).unsqueeze(0)
+ temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
else:
temporal_coords = None
if location_coords:
- location_coords = torch.tensor(location_coords[0],
- device=device).unsqueeze(0)
+ location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
else:
location_coords = None
@@ -371,26 +358,24 @@ def run_model(input_data,
pred_imgs = []
for x in windows:
# Apply standardization
- x = datamodule.test_transform(
- image=x.squeeze().numpy().transpose(1, 2, 0))
- x = datamodule.aug(x)['image']
+ x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
+ x = datamodule.aug(x)["image"]
with torch.no_grad():
x = x.to(device)
pred = model.run(x, location_coords=location_coords)
if lightning_model:
pred_lightning = lightning_model(
- x,
- temporal_coords=temporal_coords,
- location_coords=location_coords)
+ x, temporal_coords=temporal_coords, location_coords=location_coords
+ )
pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning):
print("Inference output is not equal")
y_hat = pred.argmax(dim=1)
- y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(),
- size=img_size,
- mode="nearest")
+ y_hat = torch.nn.functional.interpolate(
+ y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
+ )
pred_imgs.append(y_hat)
@@ -437,8 +422,7 @@ def parse_args():
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
- help=
- "0-based indices of the six Prithvi channels to be selected from the "
+ help="0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
)
parser.add_argument(
@@ -478,17 +462,18 @@ def main(
# Running model ------------------------------------------------------------
channels = [
- datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]
+ datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB
- pred = run_model(input_data, temporal_coords, location_coords, model_obj,
- datamodule, img_size)
+ pred = run_model(
+ input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
+ )
# Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join(
- output_dir,
- f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
+ output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
+ )
save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)
# Save image + pred
@@ -502,13 +487,13 @@ def main(
channels=channels,
)
- pred[pred == 0.] = np.nan
+ pred[pred == 0.0] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3
img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
img_pred_file = os.path.join(
- output_dir,
- f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
+ output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
+ )
save_geotiff(
image=_convert_np_uint8(img_pred),
output_path=img_pred_file,
@@ -518,8 +503,9 @@ def main(
# Save image rgb
if rgb_outputs:
rgb_file = os.path.join(
- output_dir, "original_rgb_"
- f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
+ output_dir,
+ f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
+ )
save_geotiff(
image=_convert_np_uint8(rgb_orig),
output_path=rgb_file,
@@ -528,7 +514,6 @@ def main(
if __name__ == "__main__":
-
args = parse_args()
main(**vars(args))
diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py
index 3cf0c340d670..244a64b891c9 100644
--- a/examples/offline_inference/profiling.py
+++ b/examples/offline_inference/profiling.py
@@ -44,14 +44,17 @@ def get_dtype(dtype: str):
OutputLen_NumReqs_Map: TypeAlias = dict[int, int]
-def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
- -> OutputLen_NumReqs_Map:
+
+
+def compute_request_output_lengths(
+ batch_size: int, step_requests: list[int]
+) -> OutputLen_NumReqs_Map:
"""
Given the number of requests, batch_size, and the number of requests
that each engine-step should process, step_requests, determine the
output lengths of the requests such that step_request is honoured.
- Example:
+ Example:
if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1]
then return,
{2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
@@ -100,17 +103,19 @@ def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
output_length -= 1
# sanity checks.
- assert sum(ol_nr.values()) == batch_size, \
- ("Number of requests in output-length assignment does not match "
- f"batch-size.\n batch size {batch_size} - "
- f"step requests {step_requests} - assignments {ol_nr}")
+ assert sum(ol_nr.values()) == batch_size, (
+ "Number of requests in output-length assignment does not match "
+ f"batch-size.\n batch size {batch_size} - "
+ f"step requests {step_requests} - assignments {ol_nr}"
+ )
# Check that the output-length is in [1, num-steps]. Output length must be
# at least 1 as all requests must participate in the prefill-step.
- assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
- ("Output lengths of requests should be in range "
- f"[1, num-engine-steps].\n batch size {batch_size} - "
- f"step requests {step_requests} - assignments {ol_nr}")
+ assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), (
+ "Output lengths of requests should be in range "
+ f"[1, num-engine-steps].\n batch size {batch_size} - "
+ f"step requests {step_requests} - assignments {ol_nr}"
+ )
return ol_nr
@@ -131,7 +136,7 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
context: ProfileContext object.
Returns:
- list[int]: Number of requests to process for all engine-steps.
+ list[int]: Number of requests to process for all engine-steps.
output[i], contains the number of requests that the ith step
should process.
"""
@@ -140,10 +145,13 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
# that their output lengths must be equal to num_engine_steps.
return [context.batch_size] * context.num_steps
- assert context.complete_num_requests_per_step and \
- context.complete_num_requests_per_step > 0, \
- (f"Expected a positive complete_num_requests_per_step argument."
- f"Instead got {context.complete_num_requests_per_step}")
+ assert (
+ context.complete_num_requests_per_step
+ and context.complete_num_requests_per_step > 0
+ ), (
+ f"Expected a positive complete_num_requests_per_step argument."
+ f"Instead got {context.complete_num_requests_per_step}"
+ )
# We start dropping after the first decode step.
step_requests = [
@@ -165,8 +173,9 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
return step_requests
-def run_profile(context: ProfileContext, csv_output: Optional[str],
- json_output: Optional[str]):
+def run_profile(
+ context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]
+):
print("Run profile with:")
for key, value in asdict(context).items():
print(f" {key} = {value}")
@@ -174,7 +183,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
requests_per_step: list[int] = determine_requests_per_step(context)
ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
- context.batch_size, requests_per_step)
+ context.batch_size, requests_per_step
+ )
num_steps_to_profile: int = len(requests_per_step)
max_output_len: int = max(ol_nr.keys())
@@ -186,7 +196,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
top_p=0.95,
# max_tokens is set on a per-request basis.
max_tokens=None,
- ignore_eos=True)
+ ignore_eos=True,
+ )
# Create LLM
llm = LLM(**asdict(context.engine_args))
@@ -199,31 +210,37 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
max_num_seqs = scheduler_config.max_num_seqs
if batch_size * prompt_len > max_num_batched_tokens:
- print(f"ERROR: chosen batch_size * prompt_len "
- f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
- f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
- f"and therefore cannot be run in a single profile step, please "
- f"choose a smaller batch size or prompt length, or increase "
- f"--max-num-batched-tokens")
+ print(
+ f"ERROR: chosen batch_size * prompt_len "
+ f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
+ f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
+ f"and therefore cannot be run in a single profile step, please "
+ f"choose a smaller batch size or prompt length, or increase "
+ f"--max-num-batched-tokens"
+ )
sys.exit(-1)
if batch_size > max_num_seqs:
print(
f"ERROR: chosen batch_size ({batch_size}) is larger than "
f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
- f"single profile step, please choose a smaller batch size")
+ f"single profile step, please choose a smaller batch size"
+ )
sys.exit(-1)
- print("llm.llm_engine.model_config.max_model_len: ",
- llm.llm_engine.model_config.max_model_len)
+ print(
+ "llm.llm_engine.model_config.max_model_len: ",
+ llm.llm_engine.model_config.max_model_len,
+ )
if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
- print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
- f"{max_output_len} = {prompt_len + max_output_len}) is larger "
- f"than the model's max_model_len ({max_model_len}), please "
- f"choose a smaller prompt_len or max_output_len, or increase "
- f"--max-model-len")
+ print(
+ f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
+ f"{max_output_len} = {prompt_len + max_output_len}) is larger "
+ f"than the model's max_model_len ({max_model_len}), please "
+ f"choose a smaller prompt_len or max_output_len, or increase "
+ f"--max-model-len"
+ )
sys.exit(-1)
def add_requests():
-
def get_output_len_generator() -> Generator[int, Any, Any]:
for output_len, num_reqs in ol_nr.items():
for _ in range(num_reqs):
@@ -234,13 +251,15 @@ def get_output_len_generator() -> Generator[int, Any, Any]:
sampling_params.max_tokens = next(output_len_generator)
assert isinstance(sampling_params.max_tokens, int)
- prompt_token_ids = torch.randint(llm.get_tokenizer().vocab_size,
- size=(prompt_len, )).tolist()
+ prompt_token_ids = torch.randint(
+ llm.get_tokenizer().vocab_size, size=(prompt_len,)
+ ).tolist()
llm.llm_engine.add_request(
request_id=f"seq{i}",
- prompt={'prompt_token_ids': prompt_token_ids},
- params=sampling_params)
+ prompt={"prompt_token_ids": prompt_token_ids},
+ params=sampling_params,
+ )
def abort_requests():
for i in range(batch_size):
@@ -261,10 +280,8 @@ def abort_requests():
decode_profs = []
for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
- num_running_seqs = llm.llm_engine.scheduler[
- 0].get_num_unfinished_seq_groups()
- with layerwise_profile(
- num_running_seqs=num_running_seqs) as decode_prof:
+ num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups()
+ with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof:
llm.llm_engine.step()
decode_profs.append(decode_prof)
@@ -274,8 +291,7 @@ def abort_requests():
LINE_WIDTH = 80
print("=" * LINE_WIDTH)
- print(f"= Prefill Model Table "
- f"(prompt_len={prompt_len}, batch_size={batch_size})")
+ print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
prefill_results.print_model_table()
@@ -283,16 +299,17 @@ def abort_requests():
if has_decode:
print()
print("=" * LINE_WIDTH)
- print(f"= First Decode Step Model Table "
- f"(prompt_len={prompt_len}, batch_size={batch_size})")
+ print(
+ f"= First Decode Step Model Table "
+ f"(prompt_len={prompt_len}, batch_size={batch_size})"
+ )
print("=" * LINE_WIDTH)
print()
decode_results_list[0].print_model_table()
print()
print("=" * LINE_WIDTH)
- print(f"= Prefill Summary Table "
- f"(prompt_len={prompt_len}, batch_size={batch_size})")
+ print(f"= Prefill Summary Table (prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
prefill_results.print_summary_table()
@@ -300,25 +317,32 @@ def abort_requests():
if has_decode:
print()
print("=" * LINE_WIDTH)
- print(f"= First Decode Step Summary Table "
- f"(prompt_len={prompt_len}, batch_size={batch_size})")
+ print(
+ f"= First Decode Step Summary Table "
+ f"(prompt_len={prompt_len}, batch_size={batch_size})"
+ )
print("=" * LINE_WIDTH)
print()
decode_results_list[0].print_summary_table()
if csv_output:
- csv_filename_base = csv_output[:-4] \
- if csv_output.endswith('.csv') else csv_output
+ csv_filename_base = (
+ csv_output[:-4] if csv_output.endswith(".csv") else csv_output
+ )
prefill_results.export_model_stats_table_csv(
- csv_filename_base + "_prefill_model_table.csv")
+ csv_filename_base + "_prefill_model_table.csv"
+ )
prefill_results.export_summary_stats_table_csv(
- csv_filename_base + "_prefill_summary_table.csv")
+ csv_filename_base + "_prefill_summary_table.csv"
+ )
if has_decode:
- decode_results_list[0].export_model_stats_table_csv(\
- csv_filename_base + "_decode_model_table.csv")
+ decode_results_list[0].export_model_stats_table_csv(
+ csv_filename_base + "_decode_model_table.csv"
+ )
decode_results_list[0].export_summary_stats_table_csv(
- csv_filename_base + "_decode_summary_table.csv")
+ csv_filename_base + "_decode_summary_table.csv"
+ )
if json_output:
cuda_devices = [
@@ -332,7 +356,7 @@ def abort_requests():
"torch_version": f"{torch.__version__}",
"torch_cuda_version": f"{torch.version.cuda}",
"cuda_devices": f"{cuda_devices}",
- **asdict(context)
+ **asdict(context),
},
"prefill": prefill_results.convert_stats_to_dict(),
}
@@ -342,8 +366,9 @@ def abort_requests():
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
# Add .json to json_output filename if it doesn't exist already.
- json_output_file = json_output if json_output.endswith(
- '.json') else json_output + '.json'
+ json_output_file = (
+ json_output if json_output.endswith(".json") else json_output + ".json"
+ )
with open(json_output_file, "w+") as f:
json.dump(json_dict, f, indent=2)
pass
@@ -351,16 +376,21 @@ def abort_requests():
if context.save_chrome_traces_folder is not None:
os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
prefill_prof.profiler.export_chrome_trace(
- context.save_chrome_traces_folder + "/prefill.json")
+ context.save_chrome_traces_folder + "/prefill.json"
+ )
for idx, decode_prof in enumerate(decode_profs):
decode_prof.profiler.export_chrome_trace(
- context.save_chrome_traces_folder + f"/decode_{idx + 1}.json")
- print("Traces saved as prefill.json and decode_1.json, etc."
- f" in folder {context.save_chrome_traces_folder}")
+ context.save_chrome_traces_folder + f"/decode_{idx + 1}.json"
+ )
+ print(
+ "Traces saved as prefill.json and decode_1.json, etc."
+ f" in folder {context.save_chrome_traces_folder}"
+ )
def parse_args():
- parser = FlexibleArgumentParser(description="""
+ parser = FlexibleArgumentParser(
+ description="""
Profile a model
example:
@@ -384,7 +414,8 @@ def parse_args():
--output-directory profile_breakdown --plot-metric pct_cuda_time
```
""",
- formatter_class=RawTextHelpFormatter)
+ formatter_class=RawTextHelpFormatter,
+ )
parser.add_argument(
"--csv",
type=str,
@@ -393,59 +424,68 @@ def parse_args():
"filename, will create _prefill_model_table.csv, "
"_prefill_summary_table.csv, "
"_decode_model_table.csv, and "
- "_decode_summary_table.csv")
+ "_decode_summary_table.csv",
+ )
parser.add_argument(
"--json",
type=str,
default=None,
- help="Export the results as a json file. This should be the filename")
- parser.add_argument("--save-chrome-traces-folder",
- type=str,
- help="Save chrome traces for the prefill and decode "
- "will save traces as prefill.json and decode_1.json, "
- "etc. inside this folder")
+ help="Export the results as a json file. This should be the filename",
+ )
+ parser.add_argument(
+ "--save-chrome-traces-folder",
+ type=str,
+ help="Save chrome traces for the prefill and decode "
+ "will save traces as prefill.json and decode_1.json, "
+ "etc. inside this folder",
+ )
parser.add_argument(
"--prompt-len",
type=int,
default=PROMPT_LEN_DEFAULT,
help=f"Length of the random prompt to use when profiling, all batched "
- f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
- parser.add_argument("--batch-size",
- type=int,
- default=BATCH_SIZE_DEFAULT,
- help=f"Number of requests to run as a single batch, "
- f"default={BATCH_SIZE_DEFAULT}")
+ f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=BATCH_SIZE_DEFAULT,
+ help=f"Number of requests to run as a single batch, "
+ f"default={BATCH_SIZE_DEFAULT}",
+ )
subparsers = parser.add_subparsers(dest="cmd")
run_num_steps_parser = subparsers.add_parser(
- "run_num_steps",
- help="This variation profiles n engine.step() invocations.")
+ "run_num_steps", help="This variation profiles n engine.step() invocations."
+ )
run_num_steps_parser.add_argument(
- '-n',
- '--num-steps',
+ "-n",
+ "--num-steps",
type=int,
help="Number of engine steps to profile.\n"
"Setting it to 1, profiles only the prefill step.\n"
"Setting it to 2, profiles the prefill and first decode step\n"
"Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
- "and so on ...")
+ "and so on ...",
+ )
run_to_completion_parser = subparsers.add_parser(
"run_to_completion",
help="This variation profiles all the engine.step() invocations"
- "until the engine exhausts all submitted requests.")
+ "until the engine exhausts all submitted requests.",
+ )
run_to_completion_parser.add_argument(
- '-n',
- '--complete-num-requests-per-step',
+ "-n",
+ "--complete-num-requests-per-step",
type=int,
- help=
- "Complete complete_num_requests_per_step requests every decode step."
+ help="Complete complete_num_requests_per_step requests every decode step."
"For e.g., with batch_size 128 and complete_num_requests_per_step 32,"
"the profiler is run for 6 engine steps, with the steps processing, "
"128, 128, 96, 64, 32, 1 requests respectively.\n"
"Note that we tack-on a one-request step at the end as it is often "
- "useful.")
+ "useful.",
+ )
EngineArgs.add_cli_args(parser)
@@ -459,7 +499,8 @@ def main(args):
k: v
for k, v in vars(args).items()
if k in inspect.signature(ProfileContext).parameters
- })
+ },
+ )
run_profile(context, csv_output=args.csv, json_output=args.json)
diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py
index 61da4705e18e..82737d538df4 100644
--- a/examples/offline_inference/profiling_tpu/profiling.py
+++ b/examples/offline_inference/profiling_tpu/profiling.py
@@ -31,18 +31,16 @@ def main(args: argparse.Namespace):
max_tokens=args.output_len,
)
print(sampling_params)
- dummy_prompt_token_ids = np.random.randint(10000,
- size=(args.batch_size,
- args.input_len))
- dummy_prompts: list[PromptType] = [{
- "prompt_token_ids": batch
- } for batch in dummy_prompt_token_ids.tolist()]
+ dummy_prompt_token_ids = np.random.randint(
+ 10000, size=(args.batch_size, args.input_len)
+ )
+ dummy_prompts: list[PromptType] = [
+ {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
+ ]
def run_to_completion():
start_time = time.perf_counter()
- llm.generate(dummy_prompts,
- sampling_params=sampling_params,
- use_tqdm=False)
+ llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
end_time = time.perf_counter()
latency = end_time - start_time
return latency
@@ -58,10 +56,9 @@ def run_to_completion():
profile_dir = args.profile_result_dir
print(f"Profiling (results will be saved to '{profile_dir}')...")
# Enable tracing on server
- xp.trace_detached("localhost:9012",
- profile_dir,
- delay_ms=DELAY_MS,
- duration_ms=DURATION_MS)
+ xp.trace_detached(
+ "localhost:9012", profile_dir, delay_ms=DELAY_MS, duration_ms=DURATION_MS
+ )
if DELAY_MS == 0:
time.sleep(1.0)
profile_latencies = []
@@ -72,30 +69,36 @@ def run_to_completion():
return
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = FlexibleArgumentParser(
- description='Benchmark the latency of processing a single batch of '
- 'requests till completion.')
- parser.add_argument('--input-len', type=int, default=32)
- parser.add_argument('--output-len', type=int, default=128)
- parser.add_argument('--batch-size', type=int, default=8)
- parser.add_argument('--num-iters-warmup',
- type=int,
- default=5,
- help='Number of iterations to run for warmup.')
- parser.add_argument('--num-iters',
- type=int,
- default=1,
- help='Number of iterations to run for profiling.')
+ description="Benchmark the latency of processing a single batch of "
+ "requests till completion."
+ )
+ parser.add_argument("--input-len", type=int, default=32)
+ parser.add_argument("--output-len", type=int, default=128)
+ parser.add_argument("--batch-size", type=int, default=8)
+ parser.add_argument(
+ "--num-iters-warmup",
+ type=int,
+ default=5,
+ help="Number of iterations to run for warmup.",
+ )
+ parser.add_argument(
+ "--num-iters",
+ type=int,
+ default=1,
+ help="Number of iterations to run for profiling.",
+ )
parser.add_argument(
- '--profile-result-dir',
+ "--profile-result-dir",
type=str,
default="profiles",
- help=
- ('path to save the pytorch profiler output. Can be visualized '
- 'with ui.perfetto.dev or Tensorboard '
- '(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).'
- ))
+ help=(
+ "path to save the pytorch profiler output. Can be visualized "
+ "with ui.perfetto.dev or Tensorboard "
+ "(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
+ ),
+ )
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py
index 52b6e977eaa2..eb58a11cb6fa 100644
--- a/examples/offline_inference/qwen2_5_omni/only_thinker.py
+++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
-This example shows how to use vLLM for running offline inference
+This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen2.5-Omni (thinker only).
"""
@@ -26,50 +26,53 @@ class QueryResult(NamedTuple):
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
- "generating text and speech.")
+ "generating text and speech."
+)
def get_mixed_modalities_query() -> QueryResult:
- question = ("What is recited in the audio? "
- "What is the content of this image? Why is this video funny?")
- prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
- "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
- "<|vision_bos|><|IMAGE|><|vision_eos|>"
- "<|vision_bos|><|VIDEO|><|vision_eos|>"
- f"{question}<|im_end|>\n"
- f"<|im_start|>assistant\n")
+ question = (
+ "What is recited in the audio? "
+ "What is the content of this image? Why is this video funny?"
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
+ "<|vision_bos|><|IMAGE|><|vision_eos|>"
+ "<|vision_bos|><|VIDEO|><|vision_eos|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
- "audio":
- AudioAsset("mary_had_lamb").audio_and_sample_rate,
- "image":
- ImageAsset("cherry_blossom").pil_image.convert("RGB"),
- "video":
- VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
+ "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ "image": ImageAsset("cherry_blossom").pil_image.convert("RGB"),
+ "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
},
- limit_mm_per_prompt={
- "audio": 1,
- "image": 1,
- "video": 1
- },
+ limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
def get_use_audio_in_video_query() -> QueryResult:
- question = ("Describe the content of the video, "
- "then convert what the baby say into text.")
- prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
- "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
- f"{question}<|im_end|>\n"
- f"<|im_start|>assistant\n")
+ question = (
+ "Describe the content of the video, then convert what the baby say into text."
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
- assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. "
- "Please launch this example with "
- "`VLLM_USE_V1=0`.")
+ assert not envs.VLLM_USE_V1, (
+ "V1 does not support use_audio_in_video. "
+ "Please launch this example with "
+ "`VLLM_USE_V1=0`."
+ )
return QueryResult(
inputs={
"prompt": prompt,
@@ -81,20 +84,19 @@ def get_use_audio_in_video_query() -> QueryResult:
"use_audio_in_video": True,
},
},
- limit_mm_per_prompt={
- "audio": 1,
- "video": 1
- },
+ limit_mm_per_prompt={"audio": 1, "video": 1},
)
def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
- prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
- "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
- "<|audio_bos|><|AUDIO|><|audio_eos|>"
- f"{question}<|im_end|>\n"
- f"<|im_start|>assistant\n")
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
+ "<|audio_bos|><|AUDIO|><|audio_eos|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
return QueryResult(
inputs={
"prompt": prompt,
@@ -122,18 +124,19 @@ def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B"
query_result = query_map[args.query_type]()
- llm = LLM(model=model_name,
- max_model_len=5632,
- max_num_seqs=5,
- limit_mm_per_prompt=query_result.limit_mm_per_prompt,
- seed=args.seed)
+ llm = LLM(
+ model=model_name,
+ max_model_len=5632,
+ max_num_seqs=5,
+ limit_mm_per_prompt=query_result.limit_mm_per_prompt,
+ seed=args.seed,
+ )
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
- outputs = llm.generate(query_result.inputs,
- sampling_params=sampling_params)
+ outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
@@ -142,18 +145,23 @@ def main(args):
def parse_args():
parser = FlexibleArgumentParser(
- description='Demo on using vLLM for offline inference with '
- 'audio language models')
- parser.add_argument('--query-type',
- '-q',
- type=str,
- default="mixed_modalities",
- choices=query_map.keys(),
- help='Query type.')
- parser.add_argument("--seed",
- type=int,
- default=None,
- help="Set the seed when initializing `vllm.LLM`.")
+ description="Demo on using vLLM for offline inference with "
+ "audio language models"
+ )
+ parser.add_argument(
+ "--query-type",
+ "-q",
+ type=str,
+ default="mixed_modalities",
+ choices=query_map.keys(),
+ help="Query type.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Set the seed when initializing `vllm.LLM`.",
+ )
return parser.parse_args()
diff --git a/examples/offline_inference/qwen_1m.py b/examples/offline_inference/qwen_1m.py
index 64a1f4c54b67..856a35b0e59b 100644
--- a/examples/offline_inference/qwen_1m.py
+++ b/examples/offline_inference/qwen_1m.py
@@ -17,10 +17,10 @@ def load_prompt() -> str:
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with urlopen(
- "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
- "/Qwen2.5-1M/test-data/600k.txt",
- timeout=5) as response:
- prompt = response.read().decode('utf-8')
+ "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
+ timeout=5,
+ ) as response:
+ prompt = response.read().decode("utf-8")
return prompt
@@ -41,18 +41,22 @@ def process_requests(llm: LLM, prompts: list[str]) -> None:
for output in outputs:
prompt_token_ids = output.prompt_token_ids
generated_text = output.outputs[0].text
- print(f"Prompt length: {len(prompt_token_ids)}, "
- f"Generated text: {generated_text!r}")
+ print(
+ f"Prompt length: {len(prompt_token_ids)}, "
+ f"Generated text: {generated_text!r}"
+ )
# Create an LLM.
def initialize_engine() -> LLM:
- llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",
- max_model_len=1048576,
- tensor_parallel_size=4,
- enforce_eager=True,
- enable_chunked_prefill=True,
- max_num_batched_tokens=131072)
+ llm = LLM(
+ model="Qwen/Qwen2.5-7B-Instruct-1M",
+ max_model_len=1048576,
+ tensor_parallel_size=4,
+ enforce_eager=True,
+ enable_chunked_prefill=True,
+ max_num_batched_tokens=131072,
+ )
return llm
@@ -62,5 +66,5 @@ def main():
process_requests(llm, [prompt])
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py
index e0ed0ac49754..a8f6977e29a4 100644
--- a/examples/offline_inference/rlhf.py
+++ b/examples/offline_inference/rlhf.py
@@ -12,6 +12,7 @@
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
+
import os
import ray
@@ -26,7 +27,6 @@
class MyLLM(LLM):
-
def __init__(self, *args, **kwargs):
# a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES
@@ -89,8 +89,7 @@ def __init__(self, *args, **kwargs):
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
- print(f"Prompt: {prompt!r}\n"
- f"Generated text: {generated_text!r}")
+ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# set up the communication between the training process
@@ -98,11 +97,13 @@ def __init__(self, *args, **kwargs):
master_address = get_ip()
master_port = get_open_port()
-handle = llm.collective_rpc.remote("init_weight_update_group",
- args=(master_address, master_port, 1, 3))
+handle = llm.collective_rpc.remote(
+ "init_weight_update_group", args=(master_address, master_port, 1, 3)
+)
-model_update_group = stateless_init_process_group(master_address, master_port,
- 0, 3, torch.device("cuda:0"))
+model_update_group = stateless_init_process_group(
+ master_address, master_port, 0, 3, torch.device("cuda:0")
+)
ray.get(handle)
# simulate training, modify the weights of the model.
@@ -111,8 +112,7 @@ def __init__(self, *args, **kwargs):
# sync weight from the training process to the inference engine.
for name, p in train_model.named_parameters():
- handle = llm.collective_rpc.remote("update_weight",
- args=(name, p.dtype, p.shape))
+ handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)
@@ -126,6 +126,5 @@ def __init__(self, *args, **kwargs):
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
- print(f"Prompt: {prompt!r}\n"
- f"Generated text: {generated_text!r}")
+ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py
index 3ceac0fa2e20..76eafdca1f6c 100644
--- a/examples/offline_inference/rlhf_colocate.py
+++ b/examples/offline_inference/rlhf_colocate.py
@@ -9,6 +9,7 @@
- Use cuda-ipc to pass tensors, since NCCL does not work when we have
multiple processes on the same GPU.
"""
+
import os
import ray
@@ -20,7 +21,6 @@
class MyLLM(LLM):
-
def __init__(self, *args, bundle_indices: list, **kwargs):
# a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES
@@ -29,17 +29,16 @@ def __init__(self, *args, bundle_indices: list, **kwargs):
# every worker will use 0.4 GPU, so that we can schedule
# 2 instances on the same GPUs.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
- os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(
- map(str, bundle_indices))
+ os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
print(f"creating LLM with bundle_indices={bundle_indices}")
super().__init__(*args, **kwargs)
class RayTrainingActor:
-
def __init__(self):
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
from transformers import AutoModelForCausalLM
+
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model.to("cuda:0")
for name, p in self.model.named_parameters():
@@ -48,6 +47,7 @@ def __init__(self):
# the argument for get_device_uuid is the index
# of the GPU in the visible devices.
from vllm.platforms import current_platform
+
self.device_uuid = current_platform.get_device_uuid(0)
def report_device_id(self) -> str:
@@ -55,6 +55,7 @@ def report_device_id(self) -> str:
def get_weight_ipc_handles(self):
from torch.multiprocessing.reductions import reduce_tensor
+
data = {}
for name, p in self.model.named_parameters():
# the training actor might only have a subset of the weights
@@ -101,7 +102,7 @@ def get_weight_ipc_handles(self):
print(f"training actor {bundle_index} is on {device_id}")
training_actor_device_ids.append(device_id)
-for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
+for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
# IMPORTANT: when creating vLLM instances, we need to
# make sure there are no GPU activities on the target GPUs,
# otherwise, they will interfere with the vLLM memory profiling,
@@ -128,7 +129,8 @@ def get_weight_ipc_handles(self):
for i, llm in enumerate(inference_engines):
inference_engine_device_ids.append(
- ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())))
+ ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
+ )
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
# check the placement
@@ -147,9 +149,10 @@ def get_weight_ipc_handles(self):
print("update the weights of the inference engines")
for llm in inference_engines:
ray.get(
- llm.collective_rpc.remote("update_weights_from_ipc_handles",
- args=(ipc_handles, )))
+ llm.collective_rpc.remote(
+ "update_weights_from_ipc_handles", args=(ipc_handles,)
+ )
+ )
print("check if the weights are updated")
for llm in inference_engines:
- assert ray.get(
- llm.collective_rpc.remote("check_weights_changed", args=tuple()))
+ assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py
index 11b73b7c4a0a..3461af707eba 100644
--- a/examples/offline_inference/rlhf_utils.py
+++ b/examples/offline_inference/rlhf_utils.py
@@ -2,21 +2,20 @@
import torch
-def stateless_init_process_group(master_address, master_port, rank, world_size,
- device):
+def stateless_init_process_group(master_address, master_port, rank, world_size, device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
- the data-plane communication (NCCL) between external (train processes)
+ the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
- pg = StatelessProcessGroup.create(host=master_address,
- port=master_port,
- rank=rank,
- world_size=world_size)
+
+ pg = StatelessProcessGroup.create(
+ host=master_address, port=master_port, rank=rank, world_size=world_size
+ )
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
@@ -31,9 +30,11 @@ class WorkerExtension:
should pass the full qualified name as `worker_extension_cls` argument.
"""
- def init_weight_update_group(self, master_address, master_port,
- rank_offset, world_size):
+ def init_weight_update_group(
+ self, master_address, master_port, rank_offset, world_size
+ ):
from vllm.distributed.parallel_state import get_world_group
+
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
@@ -45,9 +46,9 @@ def init_weight_update_group(self, master_address, master_port,
def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
- self.model_update_group.broadcast(weight,
- src=0,
- stream=torch.cuda.current_stream())
+ self.model_update_group.broadcast(
+ weight, src=0, stream=torch.cuda.current_stream()
+ )
self.model_runner.model.load_weights(weights=[(name, weight)])
@@ -59,8 +60,7 @@ def check_weights_changed(self):
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
- weights_updated = weights_updated and torch.allclose(
- p, torch.zeros_like(p))
+ weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
return weights_updated
@@ -76,6 +76,7 @@ class ColocateWorkerExtension:
def report_device_id(self) -> str:
from vllm.platforms import current_platform
+
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
@@ -100,6 +101,5 @@ def check_weights_changed(self):
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
- weights_updated = weights_updated and torch.allclose(
- p, torch.zeros_like(p))
+ weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
return weights_updated
diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py
index 338380cc9684..860fe2b5fe06 100644
--- a/examples/offline_inference/save_sharded_state.py
+++ b/examples/offline_inference/save_sharded_state.py
@@ -21,6 +21,7 @@
tensor_parallel_size=8,
)
"""
+
import dataclasses
import os
import shutil
@@ -33,18 +34,18 @@
def parse_args():
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
- parser.add_argument("--output",
- "-o",
- required=True,
- type=str,
- help="path to output checkpoint")
- parser.add_argument("--file-pattern",
- type=str,
- help="string pattern of saved filenames")
- parser.add_argument("--max-file-size",
- type=str,
- default=5 * 1024**3,
- help="max size (in bytes) of each safetensors file")
+ parser.add_argument(
+ "--output", "-o", required=True, type=str, help="path to output checkpoint"
+ )
+ parser.add_argument(
+ "--file-pattern", type=str, help="string pattern of saved filenames"
+ )
+ parser.add_argument(
+ "--max-file-size",
+ type=str,
+ default=5 * 1024**3,
+ help="max size (in bytes) of each safetensors file",
+ )
return parser.parse_args()
@@ -68,23 +69,23 @@ def main(args):
# For V1 engine, we need to use engine_core.save_sharded_state
print("Using V1 engine save path")
llm.llm_engine.engine_core.save_sharded_state(
- path=args.output,
- pattern=args.file_pattern,
- max_size=args.max_file_size)
+ path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
+ )
else:
# For V0 engine
print("Using V0 engine save path")
model_executor = llm.llm_engine.model_executor
- model_executor.save_sharded_state(path=args.output,
- pattern=args.file_pattern,
- max_size=args.max_file_size)
+ model_executor.save_sharded_state(
+ path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
+ )
# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
- shutil.copytree(os.path.join(model_path, file),
- os.path.join(args.output, file))
+ shutil.copytree(
+ os.path.join(model_path, file), os.path.join(args.output, file)
+ )
else:
shutil.copy(os.path.join(model_path, file), args.output)
diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py
index 363b500e0adf..9ed7299606b7 100644
--- a/examples/offline_inference/structured_outputs.py
+++ b/examples/offline_inference/structured_outputs.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
"""
-This file demonstrates the example usage of guided decoding
-to generate structured outputs using vLLM. It shows how to apply
-different guided decoding techniques such as Choice, Regex, JSON schema,
-and Grammar to produce structured and formatted results
+This file demonstrates the example usage of guided decoding
+to generate structured outputs using vLLM. It shows how to apply
+different guided decoding techniques such as Choice, Regex, JSON schema,
+and Grammar to produce structured and formatted results
based on specific prompts.
"""
@@ -15,20 +15,20 @@
from vllm.sampling_params import GuidedDecodingParams
# Guided decoding by Choice (list of possible options)
-guided_decoding_params_choice = GuidedDecodingParams(
- choice=["Positive", "Negative"])
-sampling_params_choice = SamplingParams(
- guided_decoding=guided_decoding_params_choice)
+guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
+sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"
# Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
- guided_decoding=guided_decoding_params_regex, stop=["\n"])
+ guided_decoding=guided_decoding_params_regex, stop=["\n"]
+)
prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
- "alan.turing@enigma.com\n")
+ "alan.turing@enigma.com\n"
+)
# Guided decoding by JSON using Pydantic schema
@@ -47,10 +47,11 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
-sampling_params_json = SamplingParams(
- guided_decoding=guided_decoding_params_json)
-prompt_json = ("Generate a JSON with the brand, model and car_type of"
- "the most iconic car from the 90's")
+sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
+prompt_json = (
+ "Generate a JSON with the brand, model and car_type of"
+ "the most iconic car from the 90's"
+)
# Guided decoding by Grammar
simplified_sql_grammar = """
@@ -61,12 +62,11 @@ class CarDescription(BaseModel):
condition ::= column "= " number
number ::= "1 " | "2 "
"""
-guided_decoding_params_grammar = GuidedDecodingParams(
- grammar=simplified_sql_grammar)
-sampling_params_grammar = SamplingParams(
- guided_decoding=guided_decoding_params_grammar)
-prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'"
- "from the 'users' table.")
+guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
+sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
+prompt_grammar = (
+ "Generate an SQL query to show the 'username' and 'email'from the 'users' table."
+)
def format_output(title: str, output: str):
@@ -90,8 +90,7 @@ def main():
json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Guided decoding by JSON", json_output)
- grammar_output = generate_output(prompt_grammar, sampling_params_grammar,
- llm)
+ grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
format_output("Guided decoding by Grammar", grammar_output)
diff --git a/examples/offline_inference/torchrun_example.py b/examples/offline_inference/torchrun_example.py
index bb61a0a29e32..2fa49c0835e3 100644
--- a/examples/offline_inference/torchrun_example.py
+++ b/examples/offline_inference/torchrun_example.py
@@ -45,8 +45,7 @@
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
- print(f"Prompt: {prompt!r}\n"
- f"Generated text: {generated_text!r}\n")
+ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
print("-" * 50)
"""
Further tips:
diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py
index 71cd88f2788a..e4a75b3f9380 100644
--- a/examples/offline_inference/tpu.py
+++ b/examples/offline_inference/tpu.py
@@ -20,10 +20,12 @@
def main():
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforace_eager` should be `False`.
- llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
- max_num_batched_tokens=64,
- max_num_seqs=4,
- max_model_len=128)
+ llm = LLM(
+ model="Qwen/Qwen2-1.5B-Instruct",
+ max_num_batched_tokens=64,
+ max_num_seqs=4,
+ max_model_len=128,
+ )
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output, answer in zip(outputs, answers):
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index c54f328c7a38..f2bdfe61de10 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -6,6 +6,7 @@
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
+
import os
import random
from contextlib import contextmanager
@@ -48,9 +49,13 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
- prompts = [(f"<|im_start|>user\n<|img|>{question}"
- "<|im_end|>\n<|im_start|>assistant\n")
- for question in questions]
+ prompts = [
+ (
+ f"<|im_start|>user\n<|img|>{question}"
+ "<|im_end|>\n<|im_start|>assistant\n"
+ )
+ for question in questions
+ ]
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
@@ -134,8 +139,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
)
prompts = [
- f"<|User|>: \n{question}\n\n<|Assistant|>:"
- for question in questions
+ f"<|User|>: \n{question}\n\n<|Assistant|>:" for question in questions
]
return ModelRequestData(
@@ -197,9 +201,14 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
- prompts = [("user\n"
- f"{question}\n"
- "model\n") for question in questions]
+ prompts = [
+ (
+ "user\n"
+ f"{question}\n"
+ "model\n"
+ )
+ for question in questions
+ ]
return ModelRequestData(
engine_args=engine_args,
@@ -224,7 +233,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
prompts = [
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
- {question}<|assistant|>" for question in questions
+ {question}<|assistant|>"
+ for question in questions
]
stop_token_ids = [151329, 151336, 151338]
@@ -249,15 +259,13 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
- tokenizer = AutoTokenizer.from_pretrained(model_name,
- trust_remote_code=True)
- messages = [[{
- 'role': 'user',
- 'content': f"\n{question}"
- }] for question in questions]
- prompts = tokenizer.apply_chat_template(messages,
- tokenize=False,
- add_generation_prompt=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ messages = [
+ [{"role": "user", "content": f"\n{question}"}] for question in questions
+ ]
+ prompts = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
@@ -283,15 +291,14 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
# if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={
- "size": {
- "longest_edge": 3 * 364
- },
+ "size": {"longest_edge": 3 * 364},
},
limit_mm_per_prompt={modality: 1},
)
- prompts = [(
- f"<|begin_of_text|>User:{question}\nAssistant:"
- ) for question in questions]
+ prompts = [
+ (f"<|begin_of_text|>User:{question}\nAssistant:")
+ for question in questions
+ ]
return ModelRequestData(
engine_args=engine_args,
@@ -310,9 +317,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
enforce_eager=True,
mm_processor_kwargs={
- "max_image_size": {
- "longest_edge": 384
- },
+ "max_image_size": {"longest_edge": 384},
},
limit_mm_per_prompt={modality: 1},
)
@@ -340,15 +345,13 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
- tokenizer = AutoTokenizer.from_pretrained(model_name,
- trust_remote_code=True)
- messages = [[{
- 'role': 'user',
- 'content': f"\n{question}"
- }] for question in questions]
- prompts = tokenizer.apply_chat_template(messages,
- tokenize=False,
- add_generation_prompt=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ messages = [
+ [{"role": "user", "content": f"\n{question}"}] for question in questions
+ ]
+ prompts = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
# Stop tokens for InternVL
# models variants may have different stop tokens
@@ -371,7 +374,8 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
prompts = [
"<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
f"<|media_pad|><|media_end|>{question}<|im_end|>"
- "<|im_assistant|>assistant<|im_middle|>" for question in questions
+ "<|im_assistant|>assistant<|im_middle|>"
+ for question in questions
]
engine_args = EngineArgs(
@@ -391,9 +395,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
- prompts = [
- f"USER: \n{question}\nASSISTANT:" for question in questions
- ]
+ prompts = [f"USER: \n{question}\nASSISTANT:" for question in questions]
engine_args = EngineArgs(
model="llava-hf/llava-1.5-7b-hf",
@@ -426,13 +428,10 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
# LlaVA-NeXT-Video
# Currently only support for video input
-def run_llava_next_video(questions: list[str],
- modality: str) -> ModelRequestData:
+def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "video"
- prompts = [
- f"USER: