From a7a6e438d7a79ab57694ef58cc83ccfb2746e418 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 28 Aug 2024 12:10:23 -0700 Subject: [PATCH 001/116] [Benchmark] Add async throughput benchmark Like benchmark_throughput but using AsyncLLMEngine rather than LLM --- benchmarks/benchmark_throughput_async.py | 479 +++++++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 35 +- 2 files changed, 505 insertions(+), 9 deletions(-) create mode 100644 benchmarks/benchmark_throughput_async.py diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py new file mode 100644 index 000000000000..0b9c2e16a370 --- /dev/null +++ b/benchmarks/benchmark_throughput_async.py @@ -0,0 +1,479 @@ +"""Benchmark offline inference throughput.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.entrypoints.openai.api_server import build_async_engine_client_from_engine_args +from vllm.utils import merge_async_iterators +from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.utils import FlexibleArgumentParser + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +async def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], + gpu_memory_utilization: float = 0.9, + num_scheduler_steps: int = 1, + use_v2_block_manager: bool = False, + download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, + disable_async_output_proc: bool = False, +) -> float: + from vllm import LLM, SamplingParams + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, + num_scheduler_steps=num_scheduler_steps, + use_v2_block_manager=use_v2_block_manager, + disable_async_output_proc=disable_async_output_proc, + worker_use_ray=False, + engine_use_ray=False, + disable_log_requests=True, + ) + + decoupled = True + + async with build_async_engine_client_from_engine_args(engine_args, + not decoupled) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + use_beam_search: bool, + max_batch_size: int, + trust_remote_code: bool, +) -> float: + assert not use_beam_search + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: List[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt, prompt_len, output_len = requests[i] + # Add the prompt to the batch. 
+ batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + _, next_prompt_len, next_output_len = requests[i + 1] + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=not use_beam_search, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. + batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def run_mii( + requests: List[Tuple[str, int, int]], + model: str, + tensor_parallel_size: int, + output_len: int, +) -> float: + from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) + prompts = [prompt for prompt, _, _ in requests] + + start = time.perf_counter() + llm.generate(prompts, max_new_tokens=output_len) + end = time.perf_counter() + client = client(model) + client.terminate_server() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + coro = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.num_scheduler_steps, + args.use_v2_block_manager, args.download_dir, args.load_format, + args.disable_async_output_proc) + + elapsed_time = uvloop.run(coro) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.use_beam_search, args.hf_max_batch_size, + args.trust_remote_code) + elif args.backend == "mii": + elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, + args.output_len) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with 
open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="auto", + choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') + parser.add_argument( + "--num-scheduler-steps", + type=int, + default=1, + help="Maximum number of forward steps per scheduler call.") + parser.add_argument("--use-v2-block-manager", + action='store_true', + help="Enable block manager v2.") + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="Enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' + ], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') + parser.add_argument( + "--disable-async-output-proc", + action='store_true', + default=False, + help="Disable async output processor for vLLM backend.") + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + if args.backend == "vllm": + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + elif args.backend == "hf": + if args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + elif args.backend == "mii": + if args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.use_beam_search: + raise ValueError("Beam search is not supported for MII backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + if args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII " + "backend.") + main(args) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8e8371ef1559..e99e4bd95108 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -96,6 +96,22 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + + # Context manager to handle async_engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + global engine_args + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[Optional[AsyncEngineClient]]: """ Create AsyncEngineClient, either: - in-process using the AsyncLLMEngine Directly @@ -104,22 +120,21 @@ async def build_async_engine_client( Returns the Client or None if the creation failed. """ - # Context manager to handle async_engine_client lifecycle - # Ensures everything is shutdown and cleaned up on error/exit - global engine_args - engine_args = AsyncEngineArgs.from_cli_args(args) - # Backend itself still global for the silly lil' health handler global async_engine_client # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. 
- if (model_is_embedding(args.model, args.trust_remote_code, - args.quantization) - or args.disable_frontend_multiprocessing): + if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, + engine_args.quantization) + or disable_frontend_multiprocessing): async_engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - yield async_engine_client + try: + yield async_engine_client + finally: + async_engine_client.shutdown_background_loop() + async_engine_client = None #TODO return # Otherwise, use the multiprocessing AsyncLLMEngine. @@ -192,6 +207,8 @@ async def build_async_engine_client( from prometheus_client import multiprocess multiprocess.mark_process_dead(rpc_server_process.pid) + async_engine_client = None #TODO + router = APIRouter() From ce7d15974028679c4d08742bb763489a4a06c004 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 28 Aug 2024 17:25:19 -0700 Subject: [PATCH 002/116] wip --- vllm/engine/async_llm_engine.py | 135 +++++++++++++++++------- vllm/entrypoints/openai/rpc/__init__.py | 5 + vllm/entrypoints/openai/rpc/client.py | 79 ++++++++++---- vllm/entrypoints/openai/rpc/server.py | 36 +++++-- 4 files changed, 186 insertions(+), 69 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37696bf1d9dc..fdad2d18c5b8 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -47,7 +47,6 @@ def _log_task_completion(task: asyncio.Task, there is an exception. """ - exception = None try: return_value = task.result() raise AssertionError( @@ -80,8 +79,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, - Exception]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -123,10 +121,11 @@ def _is_raisable(value: Any): class RequestTracker: """Synchronous abstraction for tracking requests.""" - def __init__(self) -> None: + def __init__(self, per_request_streams: bool = True) -> None: + self._per_request_streams = per_request_streams self._request_streams: Dict[str, AsyncStream] = {} self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[AsyncStream, + self._new_requests: asyncio.Queue[Tuple[Optional[AsyncStream], dict]] = asyncio.Queue() self.new_requests_event = asyncio.Event() @@ -186,14 +185,15 @@ def add_request(self, request_id: str, *, verbose: bool = False, - **engine_add_request_kwargs) -> AsyncStream: + **engine_add_request_kwargs) -> Optional[AsyncStream]: """Add a request to be sent to the engine on the next background loop iteration.""" if request_id in self._request_streams: raise KeyError(f"Request {request_id} already exists.") abort_request = partial(self.abort_request, verbose=verbose) - stream = AsyncStream(request_id, abort_request) + stream = AsyncStream(request_id, abort_request) \ + if self._per_request_streams else None self._new_requests.put_nowait((stream, { "request_id": request_id, **engine_add_request_kwargs @@ -234,13 +234,15 @@ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: while not self._new_requests.empty(): stream, new_request = self._new_requests.get_nowait() - request_id = stream.request_id + request_id = new_request["request_id"] if request_id in finished_requests: # The request has already been 
aborted. - stream.finish(asyncio.CancelledError) + if stream is not None: + stream.finish(asyncio.CancelledError) finished_requests.discard(request_id) else: - self._request_streams[request_id] = stream + if stream is not None: + self._request_streams[request_id] = stream new_requests.append(new_request) return new_requests, finished_requests @@ -639,7 +641,34 @@ def __init__(self, self._errored_with: Optional[BaseException] = None # Lazy initialized fields - self._request_tracker: RequestTracker + self._request_tracker: RequestTracker = None # type: ignore[assignment] + + self._global_queue: Optional[asyncio.Queue] = None + + async def global_output_generator( + self + ) -> AsyncGenerator[List[Union[RequestOutput, EmbeddingRequestOutput, + Tuple[str, BaseException]]], None]: + """Returns a single generator that streams outputs from all + requests. + + Must be called at most once prior to processing any requests, + and if used, generate() will return None rather than a per-request + stream. + """ + if self._global_queue is not None: + raise RuntimeError( + "global_output_generator can only be called once") + if self._request_tracker is not None: + raise RuntimeError( + "global_output_generator must be called before processing " + "any requests") + + self._global_queue = asyncio.Queue() + + # This runs until the engine is shut down + while True: + yield await self._global_queue.get() @classmethod def _get_executor_cls( @@ -763,6 +792,11 @@ def set_errored(self, exc: Exception) -> None: def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) + if self._global_queue is not None: + #TODO clean this up + for request_id in tuple( + self._request_tracker._request_streams.keys()): + self._global_queue.put_nowait((request_id, exc)) async def get_tokenizer( self, @@ -783,7 +817,8 @@ def start_background_loop(self) -> None: if self.is_running: raise RuntimeError("Background loop is already running.") # Initialize the RequestTracker here so it uses the right event loop. - self._request_tracker = RequestTracker() + per_request_streams = self._global_queue is None + self._request_tracker = RequestTracker(per_request_streams) self._background_loop_unshielded = asyncio.get_event_loop( ).create_task(self.run_engine_loop()) @@ -844,11 +879,14 @@ async def engine_step(self, virtual_engine: int) -> bool: await self.engine.add_request_async(**new_request) except ValueError as e: # TODO: use a vLLM specific error for failed validation + request_id = new_request["request_id"] self._request_tracker.process_exception( - new_request["request_id"], + request_id, e, verbose=self.log_requests, ) + if self._global_queue is not None: + self._global_queue.put_nowait((request_id, e)) if aborted_requests: await self._engine_abort(aborted_requests) @@ -859,13 +897,18 @@ async def engine_step(self, virtual_engine: int) -> bool: request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. 
- finished = True + all_finished = True for request_output in request_outputs: - self._request_tracker.process_request_output( - request_output, verbose=self.log_requests) - finished = finished and request_output.finished + finished = request_output.finished + if finished or self._global_queue is None: + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + all_finished = all_finished and finished + + if self._global_queue is not None: + self._global_queue.put_nowait(request_outputs) - return not finished + return not all_finished async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: @@ -950,8 +993,9 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Optional[AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], + None]]: if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -972,7 +1016,7 @@ async def add_request( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request) - return stream.generator() + return stream.generator() if stream is not None else None async def generate( self, @@ -982,7 +1026,7 @@ async def generate( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: + ) -> Optional[AsyncGenerator[RequestOutput, None]]: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -1004,6 +1048,9 @@ async def generate( The output `RequestOutput` objects from the LLMEngine for the request. + Unless a global output generator is being used, in which case + this methods will return None. + Details: - If the engine is not running, start the background loop, which iteratively invokes @@ -1047,15 +1094,22 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - inputs, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - ): - yield LLMEngine.validate_output(output, RequestOutput) + maybe_generator = await self.add_request( + request_id, + inputs, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + ) + if maybe_generator is None or not LLMEngine.DO_VALIDATE_OUTPUT: + return maybe_generator + + async def validating_generator(): + async for output in maybe_generator: + yield LLMEngine.validate_output(output, RequestOutput) + + return validating_generator() async def encode( self, @@ -1125,13 +1179,15 @@ async def encode( >>> # Process and return the final output >>> ... 
""" - async for output in await self.add_request( - request_id, - inputs, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - ): + generator = await self.add_request( + request_id, + inputs, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + ) + assert generator is not None + async for output in generator: yield LLMEngine.validate_output(output, EmbeddingRequestOutput) async def abort(self, request_id: str) -> None: @@ -1165,6 +1221,9 @@ def _abort(self, request_id: str) -> None: exception=asyncio.CancelledError, verbose=self.log_requests) + if self._global_queue is not None: + self._global_queue.put_nowait((request_id, asyncio.CancelledError)) + async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" if self.engine_use_ray: diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index efc7e43afdcc..c4cce036281a 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -17,6 +17,11 @@ VLLM_RPC_ZMQ_HWM = 0 +@dataclass +class RPCOutputStreamRequest: + pass + + @dataclass class RPCGenerateRequest: inputs: PromptInputs diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index c457555c54b9..51fbbfd3b461 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,12 +1,14 @@ import asyncio import pickle from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Mapping, Optional +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) from uuid import uuid4 import cloudpickle import zmq import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, @@ -16,7 +18,9 @@ VLLM_RPC_SOCKET_LIMIT_CUTOFF, VLLM_RPC_SUCCESS_STR, VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) + RPCGenerateRequest, + RPCOutputStreamRequest, + RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs @@ -141,12 +145,37 @@ def __init__(self, rpc_path: str): # 1 for generate(), 1 for abort(), do_log_stats(), check_health() self.limit_concurrency = socket_limit // 2 - 2 + self.output_queues: Dict[str, asyncio.Queue] = {} + + self.output_handler = asyncio.create_task(self.run_output_handler()) + async def run_proxy(self, socket_from: Socket, socket_to: Socket): """Background task that runs a proxy""" while True: frames = await socket_from.recv_multipart(copy=False) await socket_to.send_multipart(frames, copy=False) + async def run_output_handler(self): + with self.to_proxy_socket() as socket: + await socket.send_multipart( + (cloudpickle.dumps(RPCOutputStreamRequest()), )) + + # Stream back the results from the RPC Server. 
+ while True: + message: Frame = await socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + for output in request_outputs: + if isinstance(output, tuple): + # Exception case + request_id, output = output + else: + request_id = output.request_id + + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) + async def setup(self): """Setup the client before it starts sending server requests.""" @@ -379,6 +408,9 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue finished = False try: with self.to_proxy_socket() as socket: @@ -392,29 +424,30 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) - # Stream back the results from the RPC Server. - while not finished: - message = await socket.recv(copy=False) - request_output = pickle.loads(message.buffer) - - if isinstance(request_output, Exception): - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health(socket=socket) - except Exception as e: - self._errored = True - logger.exception(repr(e)) - - # NB: do before raising here so that the flag is set - # by the time the caller receives this exception - raise request_output - - finished = request_output.finished - yield request_output + ack: Frame = await socket.recv(copy=False) + if len(ack.buffer) != 0: + exception = pickle.loads(ack.buffer) + raise exception + + while not finished: + request_output = await queue.get() + if isinstance(request_output, BaseException): + finished = True + # On exception, check if the server is still healthy + # possibly setting the `errored` property. + if not self._errored: + try: + await self.check_health(socket=socket) + except Exception as e: + self._errored = True + logger.exception(repr(e)) + raise request_output + + finished = request_output.finished + yield request_output finally: + self.output_queues.pop(request_id) # Request was canceled by the client. 
if not finished and not self._errored: await self.abort(request_id) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index bebc2faedb68..42a66e35a65b 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -16,7 +16,9 @@ ParallelConfig, SchedulerConfig) from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) + RPCGenerateRequest, + RPCOutputStreamRequest, + RPCUtilityRequest) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -102,9 +104,27 @@ async def abort(self, identity, request: RPCAbortRequest): result = e await self.socket.send_multipart((identity, pickle.dumps(result))) + async def stream_outputs(self, identity): + # This runs indefinitely + #TODO handle shutdown + async for outputs in self.engine.global_output_generator(): + # Trim down contents to be equivalent to deltas (other PR for this) + # for output in outputs: + # output.prompt = None + # output.prompt_token_ids = None + # output.prompt_logprobs = None + # for o in output.outputs: + # o.token_ids = [0] + # o.text = " word" + + await self.socket.send_multipart((identity, pickle.dumps(outputs)), + copy=False) + async def generate(self, identity, generate_request: RPCGenerateRequest): + # Empty result to indicate success + result = b'' try: - results_generator = self.engine.generate( + await self.engine.generate( generate_request.inputs, sampling_params=generate_request.sampling_params, request_id=generate_request.request_id, @@ -112,13 +132,10 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): trace_headers=generate_request.trace_headers, prompt_adapter_request=generate_request.prompt_adapter_request) - async for request_output in results_generator: - await self.socket.send_multipart( - (identity, pickle.dumps(request_output)), copy=False) - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) + result = pickle.dumps(e) + + await self.socket.send_multipart((identity, result), copy=False) async def check_health(self, identity): try: @@ -156,6 +173,9 @@ def _make_handler_coro(self, identity, request = cloudpickle.loads(message.buffer) + if isinstance(request, RPCOutputStreamRequest): + return self.stream_outputs(identity) + if isinstance(request, RPCGenerateRequest): return self.generate(identity, request) From d99ce6f2c5034c4292ce90d4bbfcaa0ef0502393 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 31 Aug 2024 19:16:16 +0000 Subject: [PATCH 003/116] stash --- vllm/engine/async_llm_engine.py | 10 ++++++---- vllm/entrypoints/openai/rpc/server.py | 5 ++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index fdad2d18c5b8..75cc637ca6a1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -159,14 +159,16 @@ def process_request_output(self, if finished: stream = self._request_streams.pop(request_id, None) + if stream is not None: + stream.finish() else: stream = self._request_streams.get(request_id) # Guard against a KeyError which can occur if the request was aborted # while the output was generated - if stream is not None: - stream.put(request_output) - if finished: - stream.finish() + # if stream is not None: + # stream.put(request_output) + # if finished: + # stream.finish() if verbose and finished: logger.info("Finished request 
%s.", request_id) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 42a66e35a65b..c799ab8a35fa 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -230,6 +230,9 @@ async def run_server_loop(self): async def run_server(server: AsyncEngineRPCServer): + # import pyinstrument + + # with pyinstrument.Profiler(async_mode="disabled") as prof: # Put the server task into the asyncio loop. loop = asyncio.get_running_loop() server_task = loop.create_task(server.run_server_loop()) @@ -249,7 +252,7 @@ def signal_handler() -> None: finally: # Clean up all resources. server.cleanup() - + # prof.write_html("prof-disabled.html", show_all=True) def run_rpc_server(async_engine_args: AsyncEngineArgs, usage_context: UsageContext, rpc_path: str): From 8d6b2e9d434908c7c9cfd686bbd8674175175f95 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 17:51:49 +0000 Subject: [PATCH 004/116] remove proxy --- benchmarks/benchmark_throughput_async.py | 2 +- vllm/entrypoints/openai/rpc/client.py | 100 +++++++---------------- vllm/entrypoints/openai/rpc/server.py | 4 +- 3 files changed, 32 insertions(+), 74 deletions(-) diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py index 0b9c2e16a370..ec4351cebc29 100644 --- a/benchmarks/benchmark_throughput_async.py +++ b/benchmarks/benchmark_throughput_async.py @@ -120,7 +120,7 @@ async def run_vllm( disable_log_requests=True, ) - decoupled = True + decoupled = False async with build_async_engine_client_from_engine_args(engine_args, not decoupled) as llm: diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 51fbbfd3b461..21e8fbbefcbd 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -104,77 +104,36 @@ def __init__(self, rpc_path: str): self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS self._errored = False - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - assert isinstance(socket_limit, int) - if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests vLLM can process. Launch " - "vLLM with --disable-frontend-multiprocessing and open a " - "GitHub issue so we can investigate.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - - # In process proxy to RPC Server (uses memory-based messaging). - self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - - # Asyncio background task for the proxy. 
- self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in vLLM w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 + self.socket: Socket = self.context.socket(zmq.constants.DEALER) + self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) + self.socket.connect(rpc_path) + self.rpc_path = rpc_path + self.limit_concurrency = None self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_handler = asyncio.create_task(self.run_output_handler()) + - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" + async def run_output_handler(self): + await self.socket.send_multipart( + (cloudpickle.dumps(RPCOutputStreamRequest()), )) + + # Stream back the results from the RPC Server. while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) + message: Frame = await self.socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) - async def run_output_handler(self): - with self.to_proxy_socket() as socket: - await socket.send_multipart( - (cloudpickle.dumps(RPCOutputStreamRequest()), )) - - # Stream back the results from the RPC Server. - while True: - message: Frame = await socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) - - for output in request_outputs: - if isinstance(output, tuple): - # Exception case - request_id, output = output - else: - request_id = output.request_id - - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(output) + for output in request_outputs: + if isinstance(output, tuple): + # Exception case + request_id, output = output + else: + request_id = output.request_id + + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) async def setup(self): """Setup the client before it starts sending server requests.""" @@ -200,12 +159,11 @@ def close(self): """Destroy the ZeroMQ Context.""" # Close all sockets associated with this context and # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() + self.socket.close() self.context.destroy() @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: + def rpc_get_data_socket(self) -> Iterator[Socket]: # Connect to the RPCServer via the proxy. # Raise a sensible error if the client was already closed. @@ -221,7 +179,7 @@ def to_proxy_socket(self) -> Iterator[Socket]: socket = self.context.socket(zmq.constants.DEALER) socket.set_hwm(VLLM_RPC_ZMQ_HWM) try: - socket.connect(INPROC_PROXY_PATH) + socket.connect(self.rpc_path) yield socket finally: socket.close(linger=0) @@ -231,7 +189,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, error_message: str) -> Any: """Send an RPC request that is expecting data back.""" - with self.to_proxy_socket() as socket: + with self.rpc_get_data_socket() as socket: # Ping RPCServer with a request. 
await socket.send_multipart((cloudpickle.dumps(request), ), copy=False) @@ -280,7 +238,7 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): # Make a new socket connection. if socket is None: - with self.to_proxy_socket() as socket: + with self.rpc_get_data_socket() as socket: response = await do_rpc_call(socket, request) # Use existing socket connection. @@ -413,7 +371,7 @@ async def generate( self.output_queues[request_id] = queue finished = False try: - with self.to_proxy_socket() as socket: + with self.rpc_get_data_socket() as socket: # Send RPCGenerateRequest to the RPCServer. await socket.send_multipart((cloudpickle.dumps( RPCGenerateRequest( diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index c799ab8a35fa..4b1ba19327b7 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -40,9 +40,9 @@ def __init__(self, async_engine_args: AsyncEngineArgs, self.context = zmq.asyncio.Context() # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.DEALER) + self.socket: Socket = self.context.socket(zmq.constants.ROUTER) self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) + self.socket.bind(rpc_path) def cleanup(self): """Cleanup all resources.""" From 14f36373e0b549c55e2caa281ffb5e22904545da Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 20:52:21 +0000 Subject: [PATCH 005/116] stash --- benchmarks/benchmark_throughput_async.py | 7 +- examples/openai_completion_client.py | 6 +- vllm/engine/async_llm_engine.py | 1 + vllm/entrypoints/openai/api_server.py | 16 ++- vllm/entrypoints/openai/rpc/__init__.py | 1 + vllm/entrypoints/openai/rpc/client.py | 145 ++++++++++++----------- vllm/utils.py | 1 - 7 files changed, 95 insertions(+), 82 deletions(-) diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py index ec4351cebc29..54eed0f4de78 100644 --- a/benchmarks/benchmark_throughput_async.py +++ b/benchmarks/benchmark_throughput_async.py @@ -1,5 +1,6 @@ """Benchmark offline inference throughput.""" import argparse +import asyncio import json import random import time @@ -120,7 +121,7 @@ async def run_vllm( disable_log_requests=True, ) - decoupled = False + decoupled = True async with build_async_engine_client_from_engine_args(engine_args, not decoupled) as llm: @@ -143,15 +144,15 @@ async def run_vllm( generators = [] start = time.perf_counter() for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + # generator = await llm.generate(prompt, sp, request_id=f"test{i}") generator = llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) + generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: pass end = time.perf_counter() return end - start - def run_hf( requests: List[Tuple[str, int, int]], model: str, diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 58519f978d34..13f98d322036 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -14,14 +14,12 @@ model = models.data[0].id # Completion API -stream = False +stream = True completion = client.completions.create( model=model, prompt="A robot may not injure a human being", - echo=False, - n=2, stream=stream, - logprobs=3) + max_tokens=1000) print("Completion results:") if stream: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 
75cc637ca6a1..e4bc40150af1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1104,6 +1104,7 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, ) + return maybe_generator if maybe_generator is None or not LLMEngine.DO_VALIDATE_OUTPUT: return maybe_generator diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e99e4bd95108..34daf5bcb35f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -39,7 +39,8 @@ TokenizeResponse) # yapf: enable from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server +# from vllm.entrypoints.openai.rpc.server import run_rpc_server +from vllm.engine.llm_engine2 import run_rpc_server from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -84,8 +85,9 @@ async def _force_log(): while True: await asyncio.sleep(10) await async_engine_client.do_log_stats() - - if not engine_args.disable_log_stats: + + # if not engine_args.disable_log_stats: + if False: task = asyncio.create_task(_force_log()) _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) @@ -169,9 +171,11 @@ async def build_async_engine_client_from_engine_args( context = multiprocessing.get_context("spawn") # the current process might have CUDA context, # so we need to spawn a new process - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) + # rpc_server_process = context.Process( + # target=run_rpc_server, + # args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) + + rpc_server_process = context.Process(target=run_rpc_server, args=(engine_args,)) rpc_server_process.start() logger.info("Started engine process with PID %d", rpc_server_process.pid) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index c4cce036281a..4bf24bdc37f4 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -49,6 +49,7 @@ class RPCUtilityRequest(Enum): IS_TRACING_ENABLED = 9 START_PROFILE = 10 STOP_PROFILE = 11 + CLIENT_IS_READY = 11 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 21e8fbbefcbd..c71f25084422 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -104,24 +104,47 @@ def __init__(self, rpc_path: str): self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS self._errored = False - # IPC connection to RPC Server (uses unix sockets). 
- self.socket: Socket = self.context.socket(zmq.constants.DEALER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) - self.rpc_path = rpc_path + self.new_req_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.new_req_socket.connect("ipc:///tmp/new_req_socket") + + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect("ipc:///tmp/output_socket") + + # self.data_socket: Socket = self.context.socket(zmq.constants.DEALER) + # self.data_socket.connect("ipc:///tmp/data_socket") self.limit_concurrency = None self.output_queues: Dict[str, asyncio.Queue] = {} self.output_handler = asyncio.create_task(self.run_output_handler()) + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + # Connect to the RPCServer via the proxy. + + # Raise a sensible error if the client was already closed. + # This can happen if a server shutdown is triggered but some coroutines + # are still running requests. + # There should not be a race condition with this check because we don't + # yield to the event loop between here and opening the socket. + if self.context.closed: + raise RPCClientClosedError("The ZMQ client has already shut down") + + # Note that we use DEALER to enable asynchronous communication + # to enable streaming. + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect("ipc:///tmp/data_socket") + yield socket + finally: + socket.close(linger=0) async def run_output_handler(self): - await self.socket.send_multipart( - (cloudpickle.dumps(RPCOutputStreamRequest()), )) + # await self.socket.send_multipart( + # (cloudpickle.dumps(RPCOutputStreamRequest()), )) # Stream back the results from the RPC Server. while True: - message: Frame = await self.socket.recv(copy=False) + message: Frame = await self.output_socket.recv(copy=False) request_outputs = pickle.loads(message.buffer) for output in request_outputs: @@ -155,69 +178,50 @@ async def setup(self): enable_lora=bool(await self._get_lora_config_rpc()), ) + await self._notify_ready() + def close(self): """Destroy the ZeroMQ Context.""" # Close all sockets associated with this context and # then terminate the context. - self.socket.close() - self.context.destroy() - - @contextmanager - def rpc_get_data_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. + self.context.destroy(linger=0) - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(VLLM_RPC_ZMQ_HWM) - try: - socket.connect(self.rpc_path) - yield socket - finally: - socket.close(linger=0) async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, expected_type: Any, error_message: str) -> Any: """Send an RPC request that is expecting data back.""" - with self.rpc_get_data_socket() as socket: + with self.get_data_socket() as socket: # Ping RPCServer with a request. 
- await socket.send_multipart((cloudpickle.dumps(request), ), - copy=False) + await socket.send_multipart( + (cloudpickle.dumps(request), ), + copy=False) # Make sure the server responds if await socket.poll(timeout=self._data_timeout) == 0: raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") + f"{self._data_timeout} ms") # Await the data from the Server. frame = await socket.recv(copy=False) data = pickle.loads(frame.buffer) - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) + if isinstance(data, Exception): + # Re-raise exceptions returned by the server raise data - else: - raise ValueError(error_message) - return data + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + elif isinstance(data, Exception): + logger.error(error_message) + raise data + else: + raise ValueError(error_message) + + return data async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, @@ -236,12 +240,9 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): frame = await socket.recv(copy=False) return pickle.loads(frame.buffer) - # Make a new socket connection. if socket is None: - with self.rpc_get_data_socket() as socket: + with self.get_data_socket() as socket: response = await do_rpc_call(socket, request) - - # Use existing socket connection. else: response = await do_rpc_call(socket, request) @@ -270,6 +271,13 @@ async def _wait_for_server_rpc(self): request=RPCUtilityRequest.IS_SERVER_READY, error_message="Unable to start RPC Server") + async def _notify_ready(self): + """Get the RPCServer that the RPCClient is ready""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.CLIENT_IS_READY, + error_message="Unable to notify RPC Server of client readiness") + async def _get_model_config_rpc(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -371,21 +379,21 @@ async def generate( self.output_queues[request_id] = queue finished = False try: - with self.rpc_get_data_socket() as socket: - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - ack: Frame = await socket.recv(copy=False) - if len(ack.buffer) != 0: - exception = pickle.loads(ack.buffer) - raise exception + + # Send RPCGenerateRequest to the RPCServer. + await self.new_req_socket.send_multipart((cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)), )) + + # ack: Frame = await socket.recv(copy=False) + # if len(ack.buffer) != 0: + # exception = pickle.loads(ack.buffer) + # raise exception while not finished: request_output = await queue.get() @@ -395,7 +403,8 @@ async def generate( # possibly setting the `errored` property. 
if not self._errored: try: - await self.check_health(socket=socket) + # await self.check_health(socket=socket) + pass except Exception as e: self._errored = True logger.exception(repr(e)) diff --git a/vllm/utils.py b/vllm/utils.py index dab8e5fe0435..dd255684cd0a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -449,7 +449,6 @@ async def merge_async_iterators( It also optionally polls a provided function at least once per second to check for client cancellation. """ - # Can use anext() in python >= 3.10 awaits = { ensure_future(pair[1].__anext__()): pair From 3b8311bc70d64086fef3b034de95cf36af442eaa Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:03:04 +0000 Subject: [PATCH 006/116] added mp_llm_engine --- vllm/engine/mp_llm_engine.py | 119 +++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 vllm/engine/mp_llm_engine.py diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/mp_llm_engine.py new file mode 100644 index 000000000000..cb639021cc24 --- /dev/null +++ b/vllm/engine/mp_llm_engine.py @@ -0,0 +1,119 @@ +import zmq +import cloudpickle, pickle +from vllm.logger import init_logger +from vllm import EngineArgs, LLMEngine +from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, + VLLM_RPC_ZMQ_HWM, + RPCAbortRequest, + RPCGenerateRequest, + RPCOutputStreamRequest, + RPCUtilityRequest) + +logger = init_logger(__name__) + +class MPLLMEngine: + def __init__(self, engine_args) -> None: + self.engine = LLMEngine.from_engine_args(engine_args) + + self.ctx = zmq.Context() + + self.new_req_socket = self.ctx.socket(zmq.constants.PULL) + self.new_req_socket.bind("ipc:///tmp/new_req_socket") + + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind("ipc:///tmp/output_socket") + + self.data_socket = self.ctx.socket(zmq.constants.ROUTER) + self.data_socket.bind("ipc:///tmp/data_socket") + + def run(self): + logger.info("Running Startup Loop.") + self.startup_loop() + logger.info("Running Engine Loop.") + self.engine_loop() + + def startup_loop(self): + client_is_ready = False + while not client_is_ready: + identity, message = self.data_socket.recv_multipart(copy=False) + request = cloudpickle.loads(message.buffer) + if request in [ + RPCUtilityRequest.GET_MODEL_CONFIG, + RPCUtilityRequest.GET_PARALLEL_CONFIG, + RPCUtilityRequest.GET_DECODING_CONFIG, + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCUtilityRequest.GET_LORA_CONFIG + ]: + config = self.get_config(request) + self.data_socket.send_multipart((identity, pickle.dumps(config)), copy=False) + elif request == RPCUtilityRequest.IS_SERVER_READY: + self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + self.data_socket.send_multipart((identity, pickle.dumps(self.engine.is_tracing_enabled())), copy=False) + elif request == RPCUtilityRequest.CLIENT_IS_READY: + self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) + client_is_ready = True + self.data_socket.close() + del self.data_socket + + def engine_loop(self): + has_requests_in_progress = False + while True: + if not has_requests_in_progress: + self.wait_for_new_requests() + has_requests_in_progress = self.engine_step() + + def engine_step(self): + self.add_new_requests() + request_outputs = self.engine.step() + self.send_request_outputs(request_outputs) + + all_finished = True + for request_output in request_outputs: + finished = request_output.finished + if not 
finished: + all_finished = False + break + + return not all_finished + + def send_request_outputs(self, request_outputs): + self.output_socket.send_multipart( + (pickle.dumps(request_outputs),), copy=False) + + def add_new_requests(self): + while self.new_req_socket.poll(timeout=0) != 0: + message = self.new_req_socket.recv(copy=False) + generate_rpc_request = pickle.loads(message.buffer) + self.engine.add_request( + request_id=generate_rpc_request.request_id, + inputs=generate_rpc_request.inputs, + params=generate_rpc_request.sampling_params, + lora_request=generate_rpc_request.lora_request, + trace_headers=generate_rpc_request.trace_headers, + prompt_adapter_request=generate_rpc_request.prompt_adapter_request, + ) + + def wait_for_new_requests(self): + while self.new_req_socket.poll(timeout=1000) == 0: + logger.info("Waiting for new requests...") + logger.info("Found new request!") + + def get_config(self, request): + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + model_config = self.engine.get_model_config() + return model_config + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.engine.get_decoding_config() + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.engine.get_lora_config() + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.engine.get_scheduler_config() + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.engine.get_parallel_config() + else: + raise ValueError("Unknown Config Request: %s", request) + +def run_rpc_server(engine_args: EngineArgs): + engine = RPCLLMEngine(engine_args) + engine.run() From 5e2eb7449b3a414b4834c4bca616c7dd648a27e1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:05:35 +0000 Subject: [PATCH 007/116] fixed --- vllm/entrypoints/openai/api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 34daf5bcb35f..cdba0a0ecc9a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -40,7 +40,7 @@ # yapf: enable from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient # from vllm.entrypoints.openai.rpc.server import run_rpc_server -from vllm.engine.llm_engine2 import run_rpc_server +from vllm.engine.mp_llm_engine import run_rpc_server from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding From aa62f2e4137f6dcb0bea7b250923d4752ea95028 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:06:37 +0000 Subject: [PATCH 008/116] format --- vllm/engine/mp_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/mp_llm_engine.py index cb639021cc24..ff376208ed02 100644 --- a/vllm/engine/mp_llm_engine.py +++ b/vllm/engine/mp_llm_engine.py @@ -115,5 +115,5 @@ def get_config(self, request): raise ValueError("Unknown Config Request: %s", request) def run_rpc_server(engine_args: EngineArgs): - engine = RPCLLMEngine(engine_args) + engine = MPLLMEngine(engine_args) engine.run() From 863081bc1320244f4c172b5312791d4a38c07e19 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:19:20 +0000 Subject: [PATCH 009/116] cleanup --- vllm/entrypoints/openai/rpc/server.py | 260 -------------------------- 1 file changed, 260 deletions(-) delete 
mode 100644 vllm/entrypoints/openai/rpc/server.py diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py deleted file mode 100644 index 4b1ba19327b7..000000000000 --- a/vllm/entrypoints/openai/rpc/server.py +++ /dev/null @@ -1,260 +0,0 @@ -import asyncio -import pickle -import signal -from typing import Any, Coroutine, Union - -import cloudpickle -import uvloop -import zmq -import zmq.asyncio -from typing_extensions import Never -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, - RPCOutputStreamRequest, - RPCUtilityRequest) -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger(__name__) - -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args( - async_engine_args, usage_context=usage_context) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.ROUTER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.bind(rpc_path) - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - self.engine.shutdown_background_loop() - # Clear the engine reference so that it can be GC'ed. - del self.engine - - async def get_config(self, identity, request): - try: - config: CONFIG_TYPE - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - config = await self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - config = await self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - config = await self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - config = await self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - config = await self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - - await self.socket.send_multipart((identity, pickle.dumps(config)), - copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def is_tracing_enabled(self, identity): - """Send the is_tracing_enabled flag""" - tracing_flag = await self.engine.is_tracing_enabled() - - await self.socket.send_multipart( - (identity, pickle.dumps(tracing_flag))) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def abort(self, identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - try: - # Abort the request in the llm engine. 
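# Illustrative sketch (not from this patch): every handler in the server being
# deleted here replies with send_multipart((identity, payload)) because ROUTER
# sockets prefix each inbound message with the peer identity and route replies
# by that same frame.  Endpoint name and payloads below are invented for the
# example; assumes pyzmq is installed.
import pickle

import zmq

ADDR = "inproc://router-demo"  # illustrative endpoint

ctx = zmq.Context.instance()
router = ctx.socket(zmq.ROUTER)  # "server" side: sees peer identities
router.bind(ADDR)
dealer = ctx.socket(zmq.DEALER)  # "client" side
dealer.connect(ADDR)

dealer.send(pickle.dumps("ping"))
identity, payload = router.recv_multipart()  # identity frame comes first
router.send_multipart((identity, pickle.dumps(f"echo: {pickle.loads(payload)}")))
print(pickle.loads(dealer.recv()))  # -> "echo: ping"

dealer.close(linger=0)
router.close(linger=0)
ctx.term()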
- await self.engine.abort(request.request_id) - result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR - except Exception as e: - result = e - await self.socket.send_multipart((identity, pickle.dumps(result))) - - async def stream_outputs(self, identity): - # This runs indefinitely - #TODO handle shutdown - async for outputs in self.engine.global_output_generator(): - # Trim down contents to be equivalent to deltas (other PR for this) - # for output in outputs: - # output.prompt = None - # output.prompt_token_ids = None - # output.prompt_logprobs = None - # for o in output.outputs: - # o.token_ids = [0] - # o.text = " word" - - await self.socket.send_multipart((identity, pickle.dumps(outputs)), - copy=False) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - # Empty result to indicate success - result = b'' - try: - await self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - trace_headers=generate_request.trace_headers, - prompt_adapter_request=generate_request.prompt_adapter_request) - - except Exception as e: - result = pickle.dumps(e) - - await self.socket.send_multipart((identity, result), copy=False) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def start_profile(self, identity): - logger.info("Starting profiler...") - await self.engine.start_profile() - logger.info("Profiler started.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - async def stop_profile(self, identity): - logger.info("Stopping profiler...") - await self.engine.stop_profile() - logger.info("Profiler stopped.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - def _make_handler_coro(self, identity, - message: Frame) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCOutputStreamRequest): - return self.stream_outputs(identity) - - if isinstance(request, RPCGenerateRequest): - return self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - return self.get_config(identity, request) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: - return self.check_health(identity) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - return self.is_tracing_enabled(identity) - elif request == RPCUtilityRequest.START_PROFILE: - return self.start_profile(identity) - elif request == RPCUtilityRequest.STOP_PROFILE: - return self.stop_profile(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") - - async def 
run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart(copy=False) - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - -async def run_server(server: AsyncEngineRPCServer): - # import pyinstrument - - # with pyinstrument.Profiler(async_mode="disabled") as prof: - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. - def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("vLLM ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. - server.cleanup() - # prof.write_html("prof-disabled.html", show_all=True) - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) - uvloop.run(run_server(server)) From 965b97a9e18a83987414ec7b6d7cb083d6a83598 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:20:24 +0000 Subject: [PATCH 010/116] revert asyncllmengine --- vllm/engine/async_llm_engine.py | 214 +++++++++++++------------------- 1 file changed, 88 insertions(+), 126 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e4bc40150af1..159281dabde4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -22,11 +22,12 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -47,6 +48,7 @@ def _log_task_completion(task: asyncio.Task, there is an exception. 
""" + exception = None try: return_value = task.result() raise AssertionError( @@ -79,7 +81,8 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -121,11 +124,10 @@ def _is_raisable(value: Any): class RequestTracker: """Synchronous abstraction for tracking requests.""" - def __init__(self, per_request_streams: bool = True) -> None: - self._per_request_streams = per_request_streams + def __init__(self) -> None: self._request_streams: Dict[str, AsyncStream] = {} self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[Optional[AsyncStream], + self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() self.new_requests_event = asyncio.Event() @@ -159,16 +161,14 @@ def process_request_output(self, if finished: stream = self._request_streams.pop(request_id, None) - if stream is not None: - stream.finish() else: stream = self._request_streams.get(request_id) # Guard against a KeyError which can occur if the request was aborted # while the output was generated - # if stream is not None: - # stream.put(request_output) - # if finished: - # stream.finish() + if stream is not None: + stream.put(request_output) + if finished: + stream.finish() if verbose and finished: logger.info("Finished request %s.", request_id) @@ -187,15 +187,14 @@ def add_request(self, request_id: str, *, verbose: bool = False, - **engine_add_request_kwargs) -> Optional[AsyncStream]: + **engine_add_request_kwargs) -> AsyncStream: """Add a request to be sent to the engine on the next background loop iteration.""" if request_id in self._request_streams: raise KeyError(f"Request {request_id} already exists.") abort_request = partial(self.abort_request, verbose=verbose) - stream = AsyncStream(request_id, abort_request) \ - if self._per_request_streams else None + stream = AsyncStream(request_id, abort_request) self._new_requests.put_nowait((stream, { "request_id": request_id, **engine_add_request_kwargs @@ -236,15 +235,13 @@ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: while not self._new_requests.empty(): stream, new_request = self._new_requests.get_nowait() - request_id = new_request["request_id"] + request_id = stream.request_id if request_id in finished_requests: # The request has already been aborted. - if stream is not None: - stream.finish(asyncio.CancelledError) + stream.finish(asyncio.CancelledError) finished_requests.discard(request_id) else: - if stream is not None: - self._request_streams[request_id] = stream + self._request_streams[request_id] = stream new_requests.append(new_request) return new_requests, finished_requests @@ -283,6 +280,10 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # skip the scheduler if there are any remaining steps in the seq groups. 
@@ -293,17 +294,27 @@ async def step_async( # Clear outputs on scheduler iteration start ctx.request_outputs.clear() + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # If current scheduler iteration has no async postprocessor, - # then we need first to drain the pending async postprocessor - # before moving forward + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -315,9 +326,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -343,8 +351,13 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step # Execute the model. output = await self.model_executor.execute_model_async( @@ -354,7 +367,7 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -366,22 +379,25 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." 
# noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(virtual_engine=virtual_engine, @@ -394,7 +410,11 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - ctx.request_outputs = [] + # Multi-step case + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) @@ -643,34 +663,7 @@ def __init__(self, self._errored_with: Optional[BaseException] = None # Lazy initialized fields - self._request_tracker: RequestTracker = None # type: ignore[assignment] - - self._global_queue: Optional[asyncio.Queue] = None - - async def global_output_generator( - self - ) -> AsyncGenerator[List[Union[RequestOutput, EmbeddingRequestOutput, - Tuple[str, BaseException]]], None]: - """Returns a single generator that streams outputs from all - requests. - - Must be called at most once prior to processing any requests, - and if used, generate() will return None rather than a per-request - stream. - """ - if self._global_queue is not None: - raise RuntimeError( - "global_output_generator can only be called once") - if self._request_tracker is not None: - raise RuntimeError( - "global_output_generator must be called before processing " - "any requests") - - self._global_queue = asyncio.Queue() - - # This runs until the engine is shut down - while True: - yield await self._global_queue.get() + self._request_tracker: RequestTracker @classmethod def _get_executor_cls( @@ -794,11 +787,6 @@ def set_errored(self, exc: Exception) -> None: def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) - if self._global_queue is not None: - #TODO clean this up - for request_id in tuple( - self._request_tracker._request_streams.keys()): - self._global_queue.put_nowait((request_id, exc)) async def get_tokenizer( self, @@ -819,8 +807,7 @@ def start_background_loop(self) -> None: if self.is_running: raise RuntimeError("Background loop is already running.") # Initialize the RequestTracker here so it uses the right event loop. - per_request_streams = self._global_queue is None - self._request_tracker = RequestTracker(per_request_streams) + self._request_tracker = RequestTracker() self._background_loop_unshielded = asyncio.get_event_loop( ).create_task(self.run_engine_loop()) @@ -881,14 +868,11 @@ async def engine_step(self, virtual_engine: int) -> bool: await self.engine.add_request_async(**new_request) except ValueError as e: # TODO: use a vLLM specific error for failed validation - request_id = new_request["request_id"] self._request_tracker.process_exception( - request_id, + new_request["request_id"], e, verbose=self.log_requests, ) - if self._global_queue is not None: - self._global_queue.put_nowait((request_id, e)) if aborted_requests: await self._engine_abort(aborted_requests) @@ -899,18 +883,13 @@ async def engine_step(self, virtual_engine: int) -> bool: request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. 
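# Illustrative sketch (not from this patch): the per-request stream restored
# by this revert is, at its core, an asyncio.Queue plus an end-of-stream
# marker exposed to callers as an async generator.  Class and sentinel names
# below are invented for the example.
import asyncio


class MiniStream:
    _FINISHED = object()  # sentinel marking the end of the stream

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(self._FINISHED)

    async def generator(self):
        while True:
            item = await self._queue.get()
            if item is self._FINISHED:
                return
            if isinstance(item, BaseException):
                raise item  # propagate errors to the consumer
            yield item


async def main() -> None:
    stream = MiniStream()
    for token in ("Hello", ",", " world"):
        stream.put(token)
    stream.finish()
    async for token in stream.generator():
        print(token, end="")
    print()


asyncio.run(main())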
- all_finished = True + finished = True for request_output in request_outputs: - finished = request_output.finished - if finished or self._global_queue is None: - self._request_tracker.process_request_output( - request_output, verbose=self.log_requests) - all_finished = all_finished and finished - - if self._global_queue is not None: - self._global_queue.put_nowait(request_outputs) + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + finished = finished and request_output.finished - return not all_finished + return not finished async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: @@ -995,9 +974,8 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Optional[AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], - None]]: + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -1018,7 +996,7 @@ async def add_request( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request) - return stream.generator() if stream is not None else None + return stream.generator() async def generate( self, @@ -1028,7 +1006,7 @@ async def generate( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> Optional[AsyncGenerator[RequestOutput, None]]: + ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -1050,9 +1028,6 @@ async def generate( The output `RequestOutput` objects from the LLMEngine for the request. - Unless a global output generator is being used, in which case - this methods will return None. - Details: - If the engine is not running, start the background loop, which iteratively invokes @@ -1096,23 +1071,15 @@ async def generate( >>> # Process and return the final output >>> ... """ - maybe_generator = await self.add_request( - request_id, - inputs, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - ) - return maybe_generator - if maybe_generator is None or not LLMEngine.DO_VALIDATE_OUTPUT: - return maybe_generator - - async def validating_generator(): - async for output in maybe_generator: - yield LLMEngine.validate_output(output, RequestOutput) - - return validating_generator() + async for output in await self.add_request( + request_id, + inputs, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + ): + yield LLMEngine.validate_output(output, RequestOutput) async def encode( self, @@ -1182,15 +1149,13 @@ async def encode( >>> # Process and return the final output >>> ... 
""" - generator = await self.add_request( - request_id, - inputs, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - ) - assert generator is not None - async for output in generator: + async for output in await self.add_request( + request_id, + inputs, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + ): yield LLMEngine.validate_output(output, EmbeddingRequestOutput) async def abort(self, request_id: str) -> None: @@ -1224,9 +1189,6 @@ def _abort(self, request_id: str) -> None: exception=asyncio.CancelledError, verbose=self.log_requests) - if self._global_queue is not None: - self._global_queue.put_nowait((request_id, asyncio.CancelledError)) - async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" if self.engine_use_ray: From 8fd72f69ed8476c585b53b9d79972deb1084eda4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:23:39 +0000 Subject: [PATCH 011/116] fix nit --- vllm/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/utils.py b/vllm/utils.py index dd255684cd0a..dab8e5fe0435 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -449,6 +449,7 @@ async def merge_async_iterators( It also optionally polls a provided function at least once per second to check for client cancellation. """ + # Can use anext() in python >= 3.10 awaits = { ensure_future(pair[1].__anext__()): pair From ddeb7c672f0b88d34c575d6fdb072cb709f6bc98 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:24:38 +0000 Subject: [PATCH 012/116] format --- vllm/engine/async_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 159281dabde4..6c7e2fdc7a6d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1281,4 +1281,4 @@ async def start_profile(self) -> None: self.engine.model_executor._run_workers("start_profile") async def stop_profile(self) -> None: - self.engine.model_executor._run_workers("stop_profile") + self.engine.model_executor._run_workers("stop_profile") \ No newline at end of file From 4b111e4eed882c6489493967009c62b922f44ee1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:25:18 +0000 Subject: [PATCH 013/116] clean --- vllm/engine/async_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6c7e2fdc7a6d..159281dabde4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1281,4 +1281,4 @@ async def start_profile(self) -> None: self.engine.model_executor._run_workers("start_profile") async def stop_profile(self) -> None: - self.engine.model_executor._run_workers("stop_profile") \ No newline at end of file + self.engine.model_executor._run_workers("stop_profile") From a5ffd2c3ea387103a2bf733e24930a565c6e66ec Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:26:55 +0000 Subject: [PATCH 014/116] fix --- vllm/entrypoints/openai/rpc/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 4bf24bdc37f4..a99b6edcc65e 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -17,11 +17,6 @@ VLLM_RPC_ZMQ_HWM = 0 -@dataclass -class RPCOutputStreamRequest: - pass - - @dataclass class 
RPCGenerateRequest: inputs: PromptInputs From 139587264dadefe503fa51a0b48242b1c1267c90 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:57:24 +0000 Subject: [PATCH 015/116] stash --- vllm/engine/mp_llm_engine.py | 26 +++++--------------------- vllm/entrypoints/openai/rpc/client.py | 4 +--- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/mp_llm_engine.py index ff376208ed02..4c1ede7cedff 100644 --- a/vllm/engine/mp_llm_engine.py +++ b/vllm/engine/mp_llm_engine.py @@ -3,10 +3,6 @@ from vllm.logger import init_logger from vllm import EngineArgs, LLMEngine from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, - RPCAbortRequest, - RPCGenerateRequest, - RPCOutputStreamRequest, RPCUtilityRequest) logger = init_logger(__name__) @@ -57,25 +53,13 @@ def startup_loop(self): del self.data_socket def engine_loop(self): - has_requests_in_progress = False while True: - if not has_requests_in_progress: + if not self.engine.has_unfinished_requests(): self.wait_for_new_requests() - has_requests_in_progress = self.engine_step() - - def engine_step(self): - self.add_new_requests() - request_outputs = self.engine.step() - self.send_request_outputs(request_outputs) - - all_finished = True - for request_output in request_outputs: - finished = request_output.finished - if not finished: - all_finished = False - break - - return not all_finished + + self.add_new_requests() + request_outputs = self.engine.step() + self.send_request_outputs(request_outputs) def send_request_outputs(self, request_outputs): self.output_socket.send_multipart( diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index c71f25084422..a13e70e8f94d 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -15,11 +15,9 @@ ParallelConfig, SchedulerConfig) # yapf: disable from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SOCKET_LIMIT_CUTOFF, VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, + RPCAbortRequest, RPCGenerateRequest, - RPCOutputStreamRequest, RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS From 938cf85bda9da1880d34f46791bff4362098c7a3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 22:06:03 +0000 Subject: [PATCH 016/116] move files --- .../openai => engine}/rpc/__init__.py | 6 -- .../openai => engine}/rpc/client.py | 69 +++---------------- .../rpc_llm_engine.py} | 6 +- 3 files changed, 13 insertions(+), 68 deletions(-) rename vllm/{entrypoints/openai => engine}/rpc/__init__.py (89%) rename vllm/{entrypoints/openai => engine}/rpc/client.py (82%) rename vllm/engine/{mp_llm_engine.py => rpc/rpc_llm_engine.py} (96%) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/engine/rpc/__init__.py similarity index 89% rename from vllm/entrypoints/openai/rpc/__init__.py rename to vllm/engine/rpc/__init__.py index a99b6edcc65e..387119a1b11e 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/engine/rpc/__init__.py @@ -10,12 +10,6 @@ # Success string used for RPC instructions. VLLM_RPC_SUCCESS_STR = "SUCCESS" -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 - -# HWM is set to Infinity. 
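# Illustrative sketch (not from this patch): the VLLM_RPC_ZMQ_HWM constant
# removed here was applied via zmq's set_hwm(), where 0 disables the
# high-water mark so messages queue without being dropped under load (libzmq's
# default is a finite limit, typically 1000).  Assumes pyzmq is installed.
import zmq

ctx = zmq.Context.instance()
sock = ctx.socket(zmq.PUSH)
sock.set_hwm(0)  # 0 == unlimited send/receive high-water mark
print(sock.getsockopt(zmq.SNDHWM))  # -> 0
sock.close(linger=0)
ctx.term()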
-VLLM_RPC_ZMQ_HWM = 0 - @dataclass class RPCGenerateRequest: diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/engine/rpc/client.py similarity index 82% rename from vllm/entrypoints/openai/rpc/client.py rename to vllm/engine/rpc/client.py index a13e70e8f94d..2bcd12c4e2df 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/engine/rpc/client.py @@ -48,70 +48,21 @@ class RPCClientClosedError(Exception): class AsyncEngineRPCClient: """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. - - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. 
+ xxx """ def __init__(self, rpc_path: str): self.context = zmq.asyncio.Context() - self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS self._errored = False self.new_req_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.new_req_socket.connect("ipc:///tmp/new_req_socket") + self.new_req_socket.connect(f"{rpc_path}_new_req_socket") self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.output_socket.connect("ipc:///tmp/output_socket") + self.new_req_socket.connect(f"{rpc_path}_output_socket") - # self.data_socket: Socket = self.context.socket(zmq.constants.DEALER) - # self.data_socket.connect("ipc:///tmp/data_socket") + self.get_data_path = f"{rpc_path}_data_socket" - self.limit_concurrency = None self.output_queues: Dict[str, asyncio.Queue] = {} self.output_handler = asyncio.create_task(self.run_output_handler()) @@ -120,8 +71,8 @@ def get_data_socket(self) -> Iterator[Socket]: # Connect to the RPCServer via the proxy. # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. + # This can happen if a server shutdown is triggered but some + # coroutines are still running requests. # There should not be a race condition with this check because we don't # yield to the event loop between here and opening the socket. if self.context.closed: @@ -197,9 +148,9 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, copy=False) # Make sure the server responds - if await socket.poll(timeout=self._data_timeout) == 0: + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") # Await the data from the Server. 
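# Illustrative sketch (not from this patch): the helper being removed here
# follows a send -> poll(timeout) -> recv -> unpickle pattern and re-raises
# any Exception the server shipped back.  The endpoint, timeout value and
# function name below are invented for the example; assumes pyzmq and
# cloudpickle are installed.
import pickle
from typing import Any

import cloudpickle
import zmq
import zmq.asyncio

TIMEOUT_MS = 5000  # illustrative timeout


async def request_data(ctx: zmq.asyncio.Context, addr: str,
                       request: Any) -> Any:
    sock = ctx.socket(zmq.DEALER)
    try:
        sock.connect(addr)
        await sock.send_multipart((cloudpickle.dumps(request), ))
        if await sock.poll(timeout=TIMEOUT_MS) == 0:
            raise TimeoutError(f"Server didn't reply within {TIMEOUT_MS} ms")
        frame = await sock.recv(copy=False)
        data = pickle.loads(frame.buffer)
        if isinstance(data, Exception):
            raise data  # re-raise errors shipped back by the server
        return data
    finally:
        sock.close(linger=0)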
frame = await socket.recv(copy=False) @@ -231,9 +182,9 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): await socket.send_multipart((cloudpickle.dumps(request), )) - if await socket.poll(timeout=self._data_timeout) == 0: + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") frame = await socket.recv(copy=False) return pickle.loads(frame.buffer) diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/rpc/rpc_llm_engine.py similarity index 96% rename from vllm/engine/mp_llm_engine.py rename to vllm/engine/rpc/rpc_llm_engine.py index 4c1ede7cedff..4c6ec13134a0 100644 --- a/vllm/engine/mp_llm_engine.py +++ b/vllm/engine/rpc/rpc_llm_engine.py @@ -2,12 +2,12 @@ import cloudpickle, pickle from vllm.logger import init_logger from vllm import EngineArgs, LLMEngine -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - RPCUtilityRequest) +from vllm.engine.rpc import (VLLM_RPC_SUCCESS_STR, + RPCUtilityRequest) logger = init_logger(__name__) -class MPLLMEngine: +class RPCLLMEngine: def __init__(self, engine_args) -> None: self.engine = LLMEngine.from_engine_args(engine_args) From 72d1d4233cd24e66c72e7f5a7664a95c55942df2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 00:42:45 +0000 Subject: [PATCH 017/116] cleanup code --- vllm/engine/async_llm_engine.py | 5 - vllm/engine/protocol.py | 4 - vllm/engine/rpc/__init__.py | 45 --- vllm/engine/rpc/client.py | 396 -------------------------- vllm/engine/rpc/rpc_llm_engine.py | 103 ------- vllm/entrypoints/launcher.py | 9 - vllm/entrypoints/openai/api_server.py | 46 ++- 7 files changed, 22 insertions(+), 586 deletions(-) delete mode 100644 vllm/engine/rpc/__init__.py delete mode 100644 vllm/engine/rpc/client.py delete mode 100644 vllm/engine/rpc/rpc_llm_engine.py diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 159281dabde4..203f2f274891 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -776,11 +776,6 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return self._errored_with is not None - @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None - def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 34ae79f5fa8d..de6314d53219 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -29,10 +29,6 @@ def is_stopped(self) -> bool: def errored(self) -> bool: ... - @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - def generate( self, inputs: PromptInputs, diff --git a/vllm/engine/rpc/__init__.py b/vllm/engine/rpc/__init__.py deleted file mode 100644 index 387119a1b11e..000000000000 --- a/vllm/engine/rpc/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Mapping, Optional, Union - -from vllm.inputs import PromptInputs -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams - -# Success string used for RPC instructions. 
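# Illustrative sketch (not from this patch): the module being deleted here
# defines the RPC wire protocol as plain dataclasses plus an Enum, and the
# server dispatches purely on the type of the unpickled object.  The names
# below are invented for the example.
from dataclasses import dataclass
from enum import Enum
from typing import Union


@dataclass
class GenerateRequest:
    request_id: str
    prompt: str


@dataclass
class AbortRequest:
    request_id: str


class UtilityRequest(Enum):
    IS_SERVER_READY = 1
    DO_LOG_STATS = 2


RequestType = Union[GenerateRequest, AbortRequest, UtilityRequest]


def dispatch(request: RequestType) -> str:
    if isinstance(request, GenerateRequest):
        return f"generate {request.request_id}"
    if isinstance(request, AbortRequest):
        return f"abort {request.request_id}"
    if isinstance(request, UtilityRequest):
        return f"utility {request.name}"
    raise ValueError(f"Unknown request type: {request}")


print(dispatch(GenerateRequest("req-0", "Hello")))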
-VLLM_RPC_SUCCESS_STR = "SUCCESS" - - -@dataclass -class RPCGenerateRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCUtilityRequest(Enum): - IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - CLIENT_IS_READY = 11 - - -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] diff --git a/vllm/engine/rpc/client.py b/vllm/engine/rpc/client.py deleted file mode 100644 index 2bcd12c4e2df..000000000000 --- a/vllm/engine/rpc/client.py +++ /dev/null @@ -1,396 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union) -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -# yapf: disable -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, - RPCGenerateRequest, - RPCUtilityRequest) -# yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS -from vllm.inputs import PromptInputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -logger = init_logger(__name__) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - xxx - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._errored = False - - self.new_req_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.new_req_socket.connect(f"{rpc_path}_new_req_socket") - - self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.new_req_socket.connect(f"{rpc_path}_output_socket") - - self.get_data_path = f"{rpc_path}_data_socket" - - self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_handler = asyncio.create_task(self.run_output_handler()) - - @contextmanager - def get_data_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some - # coroutines are still running requests. 
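# Illustrative sketch (not from this patch): the get_data_socket() helper
# being removed here is a standard contextlib pattern -- open a short-lived
# socket, yield it, and guarantee close(linger=0) on exit.  The function name
# and endpoint below are invented for the example; assumes pyzmq is installed.
from contextlib import contextmanager
from typing import Iterator

import zmq


@contextmanager
def connected_socket(ctx: zmq.Context, addr: str) -> Iterator[zmq.Socket]:
    sock = ctx.socket(zmq.DEALER)
    try:
        sock.connect(addr)
        yield sock
    finally:
        sock.close(linger=0)  # drop any unsent messages immediately

# Usage (endpoint is illustrative):
#   with connected_socket(zmq.Context.instance(), "ipc:///tmp/example") as s:
#       s.send(b"ping")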
- # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - try: - socket.connect("ipc:///tmp/data_socket") - yield socket - finally: - socket.close(linger=0) - - async def run_output_handler(self): - # await self.socket.send_multipart( - # (cloudpickle.dumps(RPCOutputStreamRequest()), )) - - # Stream back the results from the RPC Server. - while True: - message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) - - for output in request_outputs: - if isinstance(output, tuple): - # Exception case - request_id, output = output - else: - request_id = output.request_id - - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(output) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - await self._notify_ready() - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.context.destroy(linger=0) - - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.get_data_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart( - (cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - # Await the data from the Server. - frame = await socket.recv(copy=False) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. 
- if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data - - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): - """Send one-way RPC request to trigger an action.""" - - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) - - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - frame = await socket.recv(copy=False) - return pickle.loads(frame.buffer) - - if socket is None: - with self.get_data_socket() as socket: - response = await do_rpc_call(socket, request) - else: - response = await do_rpc_call(socket, request) - - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, Exception): - logger.error(error_message) - raise response - raise ValueError(error_message) - - async def get_tokenizer(self, lora_request: LoRARequest): - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self): - """Wait for the RPCServer to start up.""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _notify_ready(self): - """Get the RPCServer that the RPCClient is ready""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.CLIENT_IS_READY, - error_message="Unable to notify RPC Server of client readiness") - - async def _get_model_config_rpc(self) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") - - async def _get_decoding_config_rpc(self) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") - - async def _get_parallel_config_rpc(self) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") - - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server") - - async def _get_lora_config_rpc(self) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") - - async def _is_tracing_enabled_rpc(self) -> bool: - """Get is_tracing_enabled flag from the RPCServer""" - - return await 
self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, - expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") - - async def abort(self, request_id: str): - """Send an ABORT_REQUEST signal to the RPC Server""" - - # Suppress timeouts as well. - # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") - - async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") - - @property - def is_running(self) -> bool: - return not self._errored - - @property - def is_stopped(self) -> bool: - return self._errored - - @property - def errored(self) -> bool: - return self._errored - - async def generate( - self, - inputs: PromptInputs, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - queue: asyncio.Queue[Union[RequestOutput, - BaseException]] = asyncio.Queue() - self.output_queues[request_id] = queue - finished = False - try: - - # Send RPCGenerateRequest to the RPCServer. - await self.new_req_socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - # ack: Frame = await socket.recv(copy=False) - # if len(ack.buffer) != 0: - # exception = pickle.loads(ack.buffer) - # raise exception - - while not finished: - request_output = await queue.get() - if isinstance(request_output, BaseException): - finished = True - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - # await self.check_health(socket=socket) - pass - except Exception as e: - self._errored = True - logger.exception(repr(e)) - raise request_output - - finished = request_output.finished - yield request_output - - finally: - self.output_queues.pop(request_id) - # Request was canceled by the client. 
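# Illustrative sketch (not from this patch): the finally block here is the
# usual way to detect client-side cancellation of a streaming request -- if
# the async generator is torn down before the stream finished, fire an abort.
# The toy engine below is invented for the example.
import asyncio


class ToyEngine:

    async def abort(self, request_id: str) -> None:
        print(f"aborted {request_id}")


async def stream(engine: ToyEngine, request_id: str):
    finished = False
    try:
        for i in range(10):
            await asyncio.sleep(0.01)
            finished = i == 9
            yield f"token-{i}"
    finally:
        if not finished:
            # Consumer went away early (cancellation, disconnect, ...).
            await engine.abort(request_id)


async def main() -> None:
    engine = ToyEngine()
    gen = stream(engine, "req-0")
    async for token in gen:
        if token == "token-2":
            break  # simulate the client giving up early
    await gen.aclose()  # runs the generator's finally block -> abort


asyncio.run(main())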
- if not finished and not self._errored: - await self.abort(request_id) - - async def check_health(self, socket: Optional[Socket] = None) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/engine/rpc/rpc_llm_engine.py b/vllm/engine/rpc/rpc_llm_engine.py deleted file mode 100644 index 4c6ec13134a0..000000000000 --- a/vllm/engine/rpc/rpc_llm_engine.py +++ /dev/null @@ -1,103 +0,0 @@ -import zmq -import cloudpickle, pickle -from vllm.logger import init_logger -from vllm import EngineArgs, LLMEngine -from vllm.engine.rpc import (VLLM_RPC_SUCCESS_STR, - RPCUtilityRequest) - -logger = init_logger(__name__) - -class RPCLLMEngine: - def __init__(self, engine_args) -> None: - self.engine = LLMEngine.from_engine_args(engine_args) - - self.ctx = zmq.Context() - - self.new_req_socket = self.ctx.socket(zmq.constants.PULL) - self.new_req_socket.bind("ipc:///tmp/new_req_socket") - - self.output_socket = self.ctx.socket(zmq.constants.PUSH) - self.output_socket.bind("ipc:///tmp/output_socket") - - self.data_socket = self.ctx.socket(zmq.constants.ROUTER) - self.data_socket.bind("ipc:///tmp/data_socket") - - def run(self): - logger.info("Running Startup Loop.") - self.startup_loop() - logger.info("Running Engine Loop.") - self.engine_loop() - - def startup_loop(self): - client_is_ready = False - while not client_is_ready: - identity, message = self.data_socket.recv_multipart(copy=False) - request = cloudpickle.loads(message.buffer) - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - config = self.get_config(request) - self.data_socket.send_multipart((identity, pickle.dumps(config)), copy=False) - elif request == RPCUtilityRequest.IS_SERVER_READY: - self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - self.data_socket.send_multipart((identity, pickle.dumps(self.engine.is_tracing_enabled())), copy=False) - elif request == RPCUtilityRequest.CLIENT_IS_READY: - self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) - client_is_ready = True - self.data_socket.close() - del self.data_socket - - def engine_loop(self): - while True: - if not self.engine.has_unfinished_requests(): - self.wait_for_new_requests() - - self.add_new_requests() - request_outputs = self.engine.step() - self.send_request_outputs(request_outputs) - - def send_request_outputs(self, request_outputs): - self.output_socket.send_multipart( - (pickle.dumps(request_outputs),), copy=False) - - def add_new_requests(self): - while self.new_req_socket.poll(timeout=0) != 0: - message = 
self.new_req_socket.recv(copy=False) - generate_rpc_request = pickle.loads(message.buffer) - self.engine.add_request( - request_id=generate_rpc_request.request_id, - inputs=generate_rpc_request.inputs, - params=generate_rpc_request.sampling_params, - lora_request=generate_rpc_request.lora_request, - trace_headers=generate_rpc_request.trace_headers, - prompt_adapter_request=generate_rpc_request.prompt_adapter_request, - ) - - def wait_for_new_requests(self): - while self.new_req_socket.poll(timeout=1000) == 0: - logger.info("Waiting for new requests...") - logger.info("Found new request!") - - def get_config(self, request): - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - model_config = self.engine.get_model_config() - return model_config - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - return self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - return self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - return self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - return self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - -def run_rpc_server(engine_args: EngineArgs): - engine = MPLLMEngine(engine_args) - engine.run() diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 3598872b65bb..f4a9c61a431c 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -27,15 +27,6 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if engine.limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency %s. To avoid this " - "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", engine.limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency - config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) _add_shutdown_handlers(app, server, engine) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cdba0a0ecc9a..6dcbbd433596 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -38,9 +38,8 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -# from vllm.entrypoints.openai.rpc.server import run_rpc_server -from vllm.engine.mp_llm_engine import run_rpc_server +from vllm.engine.multiprocessing.mp_client import MPEngineClient +from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -157,38 +156,37 @@ async def build_async_engine_client_from_engine_args( "and vLLM will properly handle cleanup.") # Select random path for IPC. - rpc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for RPC Path.", - rpc_path) + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) # Build RPCClient, which conforms to AsyncEngineClient Protocol. # NOTE: Actually, this is not true yet. 
We still need to support # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) - async_engine_client = rpc_client # type: ignore + mp_engine_client = MPEngineClient(ipc_path) + async_engine_client = mp_engine_client # type: ignore - # Start RPCServer in separate process (holds the AsyncLLMEngine). - context = multiprocessing.get_context("spawn") + # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process - # rpc_server_process = context.Process( - # target=run_rpc_server, - # args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) - - rpc_server_process = context.Process(target=run_rpc_server, args=(engine_args,)) - rpc_server_process.start() + context = multiprocessing.get_context("spawn") + + engine_process = context.Process( + target=run_mp_engine, + args=(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path)) + engine_process.start() logger.info("Started engine process with PID %d", - rpc_server_process.pid) + engine_process.pid) try: while True: try: - await rpc_client.setup() + await mp_engine_client.setup() break except TimeoutError: - if not rpc_server_process.is_alive(): + if not engine_process.is_alive(): logger.error( - "RPCServer process died before responding " + "Engine process died before responding " "to readiness probe") yield None return @@ -196,20 +194,20 @@ async def build_async_engine_client_from_engine_args( yield async_engine_client finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - rpc_client.close() + mp_engine_client.close() # Wait for server process to join - rpc_server_process.join() + engine_process.join() # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. # See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) + multiprocess.mark_process_dead(engine_process.pid) async_engine_client = None #TODO From fcdcfc921540cc3c115bb314187ddba5af17522f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 00:42:59 +0000 Subject: [PATCH 018/116] refactor, cleanup --- vllm/engine/multiprocessing/__init__.py | 51 +++ vllm/engine/multiprocessing/mp_client.py | 368 +++++++++++++++++++ vllm/engine/multiprocessing/mp_llm_engine.py | 253 +++++++++++++ 3 files changed, 672 insertions(+) create mode 100644 vllm/engine/multiprocessing/__init__.py create mode 100644 vllm/engine/multiprocessing/mp_client.py create mode 100644 vllm/engine/multiprocessing/mp_llm_engine.py diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 000000000000..c6ecb6aa7545 --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +# Success string used for RPC instructions. 
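# Illustrative sketch, not part of this patch: the wire convention defined just
# below is pickled dataclasses / enum members, one per ZMQ frame, with
# VLLM_RPC_SUCCESS_STR as the acknowledgement for one-way actions. The
# GenerateRequest / StartupRequest names here are stand-ins, not vLLM classes.
import pickle
from dataclasses import dataclass
from enum import Enum

@dataclass
class GenerateRequest:          # stand-in for RPCGenerateRequest
    request_id: str
    prompt: str

class StartupRequest(Enum):     # stand-in for the startup/utility enums
    IS_SERVER_READY = 1

SUCCESS = "SUCCESS"

# Client side: one request serialized per frame.
frame = pickle.dumps(GenerateRequest(request_id="req-0", prompt="Hello"))

# Server side: load, dispatch on the type, answer one-way actions with the sentinel.
request = pickle.loads(frame)
reply = pickle.dumps(SUCCESS)
assert isinstance(request, GenerateRequest) and pickle.loads(reply) == SUCCESS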
+VLLM_RPC_SUCCESS_STR = "SUCCESS" + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCAbortRequest: + request_id: str + +class RPCUtilityRequest(Enum): + IS_SERVER_READY = 1 + GET_MODEL_CONFIG = 2 + GET_DECODING_CONFIG = 3 + GET_PARALLEL_CONFIG = 4 + GET_SCHEDULER_CONFIG = 5 + GET_LORA_CONFIG = 6 + DO_LOG_STATS = 7 + IS_SERVER_HEALTHY = 8 + IS_TRACING_ENABLED = 9 + START_PROFILE = 10 + STOP_PROFILE = 11 + CLIENT_IS_READY = 11 + + +RPC_COFNIG_REQUEST = [ + RPCUtilityRequest.GET_MODEL_CONFIG, + RPCUtilityRequest.GET_PARALLEL_CONFIG, + RPCUtilityRequest.GET_DECODING_CONFIG, + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCUtilityRequest.GET_LORA_CONFIG +] + +RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, + RPCUtilityRequest] diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py new file mode 100644 index 000000000000..fd3011d54888 --- /dev/null +++ b/vllm/engine/multiprocessing/mp_client.py @@ -0,0 +1,368 @@ +import asyncio +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) +from uuid import uuid4 + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf: disable +from vllm.engine.multiprocessing import (RPC_REQUEST_TYPE, + VLLM_RPC_SUCCESS_STR, + RPCAbortRequest, + RPCGenerateRequest, + RPCUtilityRequest) +# yapf: enable +from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +logger = init_logger(__name__) + + +class MPEngineClient: + + def __init__(self, ipc_path: str): + self.context = zmq.asyncio.Context() + self._errored = False + + # Send RPCGenerateRequest to the MPLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}_input_socket") + + # Recieve streams of RequestOutput from the MPLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}_output_socket") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}_data_socket" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_handler = asyncio.create_task(self.run_output_handler()) + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_output_handler(self): + # Stream lists of RequestOutput from MPLLMEngine. 
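# Illustrative sketch, not part of this patch: the demultiplexing idea behind
# this handler, reduced to a standalone script. A background task pulls pickled
# batches from a PULL socket and fans each item out to a per-request
# asyncio.Queue; the (request_id, text) tuples below stand in for RequestOutput.
import asyncio
import pickle

import zmq
import zmq.asyncio


async def main():
    ctx = zmq.asyncio.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("inproc://outputs")
    push = ctx.socket(zmq.PUSH)
    push.connect("inproc://outputs")

    queues = {"req-0": asyncio.Queue(), "req-1": asyncio.Queue()}

    async def output_handler():
        while True:
            batch = pickle.loads(await pull.recv())   # one frame = list of outputs
            for request_id, text in batch:
                queues[request_id].put_nowait(text)   # route to the right stream

    handler = asyncio.create_task(output_handler())
    await push.send(pickle.dumps([("req-0", "a"), ("req-1", "b")]))
    print(await queues["req-0"].get(), await queues["req-1"].get())

    handler.cancel()
    await asyncio.gather(handler, return_exceptions=True)
    ctx.destroy(linger=0)


asyncio.run(main())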
+ while True: + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + for output in request_outputs: + if isinstance(output, tuple): + # Exception case + request_id, output = output + else: + request_id = output.request_id + + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Wait until server is ready. + await self._wait_for_server_rpc() + + # Get the configs. + self.model_config = await self._get_model_config_rpc() + self.decoding_config = await self._get_decoding_config_rpc() + self.tracing_flag = await self._is_tracing_enabled_rpc() + + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc()), + parallel_config=(await self._get_parallel_config_rpc()), + enable_lora=bool(await self._get_lora_config_rpc()), + ) + + await self._notify_ready() + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets associated with this context and + # then terminate the context. + self.context.destroy(linger=0) + + + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + expected_type: Any, + error_message: str) -> Any: + """Send an RPC request that is expecting data back.""" + + with self.get_data_socket() as socket: + # Ping RPCServer with a request. + await socket.send_multipart( + (cloudpickle.dumps(request), ), + copy=False) + + # Make sure the server responds + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("Server didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, Exception): + # Re-raise exceptions returned by the server + raise data + + if not isinstance(data, expected_type): + # LoRAConfig can be None. 
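# Illustrative sketch, not part of this patch: the acceptance rule applied here,
# written out as a standalone helper. A reply is valid if it is the expected
# config type, or None when the config is optional (e.g. LoRA disabled);
# exceptions shipped back by the server are re-raised on the client.
from typing import Any, Type


def validate_reply(data: Any, expected_type: Type, optional: bool = False) -> Any:
    if isinstance(data, Exception):
        raise data                       # propagate server-side failures
    if data is None and optional:
        return None                      # e.g. LoRAConfig when LoRA is not enabled
    if not isinstance(data, expected_type):
        raise ValueError(
            f"expected {expected_type.__name__}, got {type(data).__name__}")
    return data


assert validate_reply(None, dict, optional=True) is None
assert validate_reply({"max_model_len": 2048}, dict) == {"max_model_len": 2048}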
+ if expected_type == LoRAConfig and data is None: + pass + elif isinstance(data, Exception): + logger.error(error_message) + raise data + else: + raise ValueError(error_message) + + return data + + async def _send_one_way_rpc_request(self, + request: RPC_REQUEST_TYPE, + error_message: str, + socket: Optional[Socket] = None): + """Send one-way RPC request to trigger an action.""" + + async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): + + await socket.send_multipart((cloudpickle.dumps(request), )) + + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("Server didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + + frame = await socket.recv(copy=False) + return pickle.loads(frame.buffer) + + if socket is None: + with self.get_data_socket() as socket: + response = await do_rpc_call(socket, request) + else: + response = await do_rpc_call(socket, request) + + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + if isinstance(response, Exception): + logger.error(error_message) + raise response + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self): + """Wait for the RPCServer to start up.""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_READY, + error_message="Unable to start RPC Server") + + async def _notify_ready(self): + """Get the RPCServer that the RPCClient is ready""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.CLIENT_IS_READY, + error_message="Unable to notify RPC Server of client readiness") + + async def _get_model_config_rpc(self) -> ModelConfig: + """Get the ModelConfig object from the RPC Server""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server") + + async def _get_decoding_config_rpc(self) -> DecodingConfig: + """Get DecodingConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_DECODING_CONFIG, + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server") + + async def _get_parallel_config_rpc(self) -> ParallelConfig: + """Get ParallelConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_PARALLEL_CONFIG, + expected_type=ParallelConfig, + error_message="Could not get ParallelConfig from RPC Server") + + async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC Server") + + async def _get_lora_config_rpc(self) -> LoRAConfig: + """Get LoRAConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_LORA_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server") + + async def _is_tracing_enabled_rpc(self) -> bool: + """Get is_tracing_enabled flag from the RPCServer""" + + return await 
self._send_get_data_rpc_request( + RPCUtilityRequest.IS_TRACING_ENABLED, + expected_type=bool, + error_message="Could not get is_tracing_enabled from RPC Server") + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + # Suppress timeouts as well. + # In cases where the server is busy processing requests and a very + # large volume of abort requests arrive, it is likely that the server + # will not be able to ack all of them in time. We have seen this when + # we abort 20k requests at once while another 2k are processing- many + # of them time out, but we see the server successfully abort all of the + # requests. + # In this case we assume that the server has received or will receive + # these abort requests, and ignore the timeout. This prevents a massive + # wall of `TimeoutError` stack traces. + with suppress(RPCClientClosedError, TimeoutError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), + error_message=f"RPCAbortRequest {request_id} failed") + + async def do_log_stats(self): + """Send a DO_LOG_STATS signal to the RPC Server""" + with suppress(RPCClientClosedError): + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + error_message="RPCRequest DO_LOG_STATS failed.") + + @property + def is_running(self) -> bool: + return not self._errored + + @property + def is_stopped(self) -> bool: + return self._errored + + @property + def errored(self) -> bool: + return self._errored + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + finished = False + try: + + # Send RPCGenerateRequest to the RPCServer. + await self.input_socket.send_multipart((cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)), )) + + # ack: Frame = await socket.recv(copy=False) + # if len(ack.buffer) != 0: + # exception = pickle.loads(ack.buffer) + # raise exception + + while not finished: + request_output = await queue.get() + if isinstance(request_output, BaseException): + finished = True + # On exception, check if the server is still healthy + # possibly setting the `errored` property. + if not self._errored: + try: + # await self.check_health(socket=socket) + pass + except Exception as e: + self._errored = True + logger.exception(repr(e)) + raise request_output + + finished = request_output.finished + yield request_output + + finally: + self.output_queues.pop(request_id) + # Request was canceled by the client. 
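# Illustrative sketch, not part of this patch: why this `finally` block is the
# right place to fire an abort. When the consumer abandons the stream early,
# closing the async generator runs `finally`, and awaiting there is allowed, so
# an abort for the unfinished request can still be sent.
import asyncio


async def fake_abort(request_id: str) -> None:       # stand-in for self.abort(...)
    print(f"abort sent for {request_id}")


async def stream(request_id: str):
    finished = False
    try:
        for i in range(3):
            await asyncio.sleep(0)
            yield i
        finished = True                               # only reached if fully consumed
    finally:
        if not finished:
            await fake_abort(request_id)              # request canceled by the client


async def main():
    gen = stream("req-0")
    async for item in gen:
        if item == 1:
            break                                     # stop consuming early
    await gen.aclose()                                # triggers the finally block


asyncio.run(main())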
+ if not finished and not self._errored: + await self.abort(request_id) + + async def check_health(self, socket: Optional[Socket] = None) -> None: + """Raise if unhealthy""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_HEALTHY, + error_message="Got Unhealthy response from RPC Server", + socket=socket) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.START_PROFILE, + error_message="RPCRequest START_PROFILE failed.") + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.STOP_PROFILE, + error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py new file mode 100644 index 000000000000..0671c48d84c6 --- /dev/null +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -0,0 +1,253 @@ +import ray +import zmq +import cloudpickle +import pickle +from typing import Any, Type, Union, Iterator +from contextlib import contextmanager + +import vllm.envs as envs +from vllm import AsyncEngineArgs, LLMEngine, AsyncLLMEngine +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, + RPCUtilityRequest) +from vllm.utils import print_warning_once +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +class MPLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in asynchronous manner. It runs a background loop and uses zeromq to + recieve new requests and stream outputs incrementally to another process. + + The :class:`LLMEngine` is kicked off when a new RPCGenerateRequest + is recieved by the input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()` and sends the RequestOutputs back over + the output_socket. + + Args: + worker_use_ray: Whether to use Ray for model workers. Required for + distributed execution. Should be the same as + `parallel_config.worker_use_ray`. + engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the + async frontend will be executed in a separate process as the + model workers. + async_engine_args: AsyncLLMEngine args + log_requests: Whether to log the requests. + """ + + _engine_class: Type[LLMEngine] = LLMEngine + + def __init__(self, + worker_use_ray: bool, + engine_use_ray: bool, + *args, + ipc_path: str, + log_requests: bool = True, + **kwargs) -> None: + + if engine_use_ray: + raise NotImplementedError("Not yet supported.") + + self.worker_use_ray = worker_use_ray + self.engine_use_ray = engine_use_ray + self.log_requests = log_requests + self.engine = self._init_engine(*args, **kwargs) + + if self.engine_use_ray: + print_warning_once( + "DEPRECATED. `--engine-use-ray` is deprecated and will " + "be removed in a future update. 
" + "See https://github.com/vllm-project/vllm/issues/7045.") + + if envs.VLLM_ALLOW_ENGINE_USE_RAY: + print_warning_once( + "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray") + else: + raise ValueError("`--engine-use-ray` is deprecated. " + "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to " + "force use it") + + self.ctx = zmq.Context() + + # Recieve RPCGenerateRequest from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}_input_socket") + + # Send streams of RequestOutput back to Client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}_output_socket") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}_data_socket" + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an RPCLLM engine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + if engine_args.engine_use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + # TODO: better abstraction? + executor_class = AsyncLLMEngine._get_executor_cls(engine_config) + + return cls( + executor_class.uses_ray, + engine_args.engine_use_ray, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + ipc_path=ipc_path, + ) + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + self.input_socket.close() + self.output_socket.close() + self.ctx.destroy(linger=0) + del self.engine + + def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: + """Initialize the LLMEngine""" + + if not self.engine_use_ray: + engine_class = self._engine_class + elif self.worker_use_ray: + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + else: + # FIXME(woosuk): This is a bit hacky. Be careful when changing the + # order of the arguments. + cache_config = kwargs["cache_config"] + parallel_config = kwargs["parallel_config"] + if (parallel_config.tensor_parallel_size == 1 + and parallel_config.pipeline_parallel_size == 1): + num_gpus = cache_config.gpu_memory_utilization + else: + num_gpus = 1 + engine_class = ray.remote(num_gpus=num_gpus)( + self._engine_class).remote + return engine_class(*args, **kwargs) + + def run_background_loop(self): + """Entrypoint that kicks off the background processing loop.""" + + # Allow RPCClient to query data in startup phase. + self.run_startup_loop() + + # Kick off core processing loop. + self.run_engine_loop() + + @contextmanager + def make_data_socket(self) -> Iterator[zmq.Socket]: + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Loop over startup RPCRequests from RPCClient.""" + + with self.make_data_socket() as socket: + + # Loop until the RPCClient has all the data it needs. + client_is_ready = False + while not client_is_ready: + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCUtilityRequest = cloudpickle.loads(message.buffer) + + # Handle the query from the Client. 
+ if request == RPCUtilityRequest.GET_MODEL_CONFIG: + response = self.engine.get_model_config() + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + response = self.engine.get_decoding_config() + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + response = self.engine.get_lora_config() + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + response = self.engine.get_scheduler_config() + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + response = self.engine.get_parallel_config() + elif request == RPCUtilityRequest.IS_SERVER_READY: + response = VLLM_RPC_SUCCESS_STR + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + response = self.engine.is_tracing_enabled() + elif request == RPCUtilityRequest.CLIENT_IS_READY: + response = VLLM_RPC_SUCCESS_STR + # Once client ready, breakout of loop. + client_is_ready = True + else: + raise ValueError(f"Unknown RPCRequest: {request}") + + socket.send_multipart( + (identity, pickle.dumps(response)), copy=False) + + except Exception as e: + socket.send_multipart((identity, pickle.dumps(e)), copy=False) + + def run_engine_loop(self) -> None: + # TODO: handle PP + + while True: + # Block until there is a new request. + if not self.engine.has_unfinished_requests(): + self.wait_for_new_requests() + + # Add new work from input socket. + self.maybe_add_new_requests() + + # Engine step. + request_outputs = self.engine.step() + + # Stream results to output socket. + self.stream_outputs(request_outputs) + + + def wait_for_new_requests(self): + while self.input_socket.poll(timeout=10000) == 0: + logger.debug("Waiting for new request.") + + def stream_outputs(self, request_outputs): + self.output_socket.send_multipart( + (pickle.dumps(request_outputs),), copy=False) + + def maybe_add_new_requests(self): + while self.input_socket.poll(timeout=0) != 0: + message = self.input_socket.recv(copy=False) + generate_rpc_request = pickle.loads(message.buffer) + self.engine.add_request( + request_id=generate_rpc_request.request_id, + inputs=generate_rpc_request.inputs, + params=generate_rpc_request.sampling_params, + lora_request=generate_rpc_request.lora_request, + trace_headers=generate_rpc_request.trace_headers, + prompt_adapter_request=generate_rpc_request.prompt_adapter_request, + ) + + +def run_mp_engine(engine_args: AsyncEngineArgs, + usage_context: UsageContext, + ipc_path: str): + engine = MPLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + + engine.run_background_loop() From 659169ee8290812e1e32d89f9bed33a9ea8fe196 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 01:42:58 +0000 Subject: [PATCH 019/116] updated --- examples/openai_completion_client.py | 2 +- vllm/engine/multiprocessing/__init__.py | 22 +- vllm/engine/multiprocessing/mp_client.py | 237 ++++++++++--------- vllm/engine/multiprocessing/mp_llm_engine.py | 116 ++++----- vllm/entrypoints/openai/api_server.py | 5 +- 5 files changed, 195 insertions(+), 187 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 13f98d322036..0b77ed4d2558 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -19,7 +19,7 @@ model=model, prompt="A robot may not injure a human being", stream=stream, - max_tokens=1000) + max_tokens=100) print("Completion results:") if stream: diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index c6ecb6aa7545..be7d80072f96 100644 --- 
a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -25,27 +25,19 @@ class RPCAbortRequest: request_id: str class RPCUtilityRequest(Enum): + DO_LOG_STATS = 1 + CHECK_HEALTH = 2 + +class RPCStartupRequest(Enum): IS_SERVER_READY = 1 GET_MODEL_CONFIG = 2 GET_DECODING_CONFIG = 3 GET_PARALLEL_CONFIG = 4 GET_SCHEDULER_CONFIG = 5 GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - CLIENT_IS_READY = 11 - - -RPC_COFNIG_REQUEST = [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG -] + GET_TRACING_ENABLED = 7 + CLIENT_IS_READY = 8 + RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, RPCUtilityRequest] diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index fd3011d54888..086242d28fb5 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -1,9 +1,8 @@ import asyncio import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, Optional, Union) -from uuid import uuid4 import cloudpickle import zmq @@ -18,6 +17,7 @@ VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCGenerateRequest, + RPCStartupRequest, RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS @@ -31,6 +31,15 @@ logger = init_logger(__name__) +class MPClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ class MPEngineClient: @@ -82,24 +91,27 @@ async def run_output_handler(self): async def setup(self): """Setup the client before it starts sending server requests.""" - # Wait until server is ready. - await self._wait_for_server_rpc() + with self.get_data_socket() as socket: + + # Wait until server is ready. + await self._wait_for_server_rpc(socket) - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() + # Get the configs. + self.model_config = await self._get_model_config_rpc(socket) + self.decoding_config = await self._get_decoding_config_rpc(socket) + self.tracing_flag = await self._is_tracing_enabled_rpc(socket) - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. 
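# Illustrative sketch, not part of this patch: the DEALER/ROUTER round trip
# behind each of these _rpc(socket) config fetches, as a standalone script. The
# ROUTER sees [identity, payload] and must echo the identity frame so the reply
# reaches the DEALER that asked; the config value here is a made-up stand-in.
import pickle

import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("inproc://startup")
dealer = ctx.socket(zmq.DEALER)
dealer.connect("inproc://startup")

dealer.send(pickle.dumps("GET_MODEL_CONFIG"))           # client asks for a config

identity, payload = router.recv_multipart()             # server: [identity, request]
request = pickle.loads(payload)
reply = {"model": "dummy-model"} if request == "GET_MODEL_CONFIG" else None
router.send_multipart([identity, pickle.dumps(reply)])  # route reply by identity

print(pickle.loads(dealer.recv()))                      # client sees the bare reply

ctx.destroy(linger=0)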
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc(socket)), + parallel_config=(await self._get_parallel_config_rpc(socket)), + enable_lora=bool(await self._get_lora_config_rpc(socket)), + ) - await self._notify_ready() + # Notify MPLLMEngine client is ready to start sending requests. + await self._notify_ready(socket) def close(self): """Destroy the ZeroMQ Context.""" @@ -110,64 +122,63 @@ def close(self): async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, expected_type: Any, - error_message: str) -> Any: + error_message: str, + socket: Socket) -> Any: """Send an RPC request that is expecting data back.""" - with self.get_data_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart( - (cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - # Await the data from the Server. - frame = await socket.recv(copy=False) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server + # Ping RPCServer with a request. + await socket.send_multipart( + (cloudpickle.dumps(request), ), + copy=False) + + # Make sure the server responds + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("Server didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, Exception): + # Re-raise exceptions returned by the server + raise data + + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + elif isinstance(data, Exception): + logger.error(error_message) raise data + else: + raise ValueError(error_message) - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data + return data async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): + socket: Socket): """Send one-way RPC request to trigger an action.""" - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) + await socket.send_multipart((cloudpickle.dumps(request), )) - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - frame = await socket.recv(copy=False) - return pickle.loads(frame.buffer) - - if socket is None: - with self.get_data_socket() as socket: - response = await do_rpc_call(socket, request) - else: - response = await do_rpc_call(socket, request) + # TODO: is there a way to ack this if we are using the input_socket? 
+ # I don't think so, b/c we are using PUSH/PULL + + async def _awk_one_way_rpc_request(self, + timeout: int, + expected_str: str, + error_message: str, + socket: Socket,): + if await socket.poll(timeout=timeout) == 0: + raise TimeoutError(f"MPLLMEngine didn't reply within {timeout}ms") + + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + if not isinstance(response, str) or response != expected_str: if isinstance(response, Exception): logger.error(error_message) raise response @@ -185,72 +196,86 @@ async def get_model_config(self) -> ModelConfig: async def is_tracing_enabled(self) -> bool: return self.tracing_flag - async def _wait_for_server_rpc(self): + async def _wait_for_server_rpc(self, socket: Socket): """Wait for the RPCServer to start up.""" + + # Readiness probe. + request = RPCStartupRequest.IS_SERVER_READY + await socket.send_multipart((cloudpickle.dumps(request), )) + + # Raises TimeoutError if not awk, causing a retry. + await self._awk_one_way_rpc_request( + expected_str=VLLM_RPC_SUCCESS_STR, + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + error_message="Unable to start RPC Server", + socket=socket) + - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _notify_ready(self): + async def _notify_ready(self, socket: Socket): """Get the RPCServer that the RPCClient is ready""" await self._send_one_way_rpc_request( - request=RPCUtilityRequest.CLIENT_IS_READY, - error_message="Unable to notify RPC Server of client readiness") + request=RPCStartupRequest.CLIENT_IS_READY, + socket=socket) - async def _get_model_config_rpc(self) -> ModelConfig: + async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, + RPCStartupRequest.GET_MODEL_CONFIG, expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") + error_message="Could not get ModelConfig from RPC Server", + socket=socket) - async def _get_decoding_config_rpc(self) -> DecodingConfig: + async def _get_decoding_config_rpc(self, socket: Socket) -> DecodingConfig: """Get DecodingConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, + RPCStartupRequest.GET_DECODING_CONFIG, expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") + error_message="Could not get DecodingConfig from RPC Server", + socket=socket) - async def _get_parallel_config_rpc(self) -> ParallelConfig: + async def _get_parallel_config_rpc(self, socket: Socket) -> ParallelConfig: """Get ParallelConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, + RPCStartupRequest.GET_PARALLEL_CONFIG, expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") + error_message="Could not get ParallelConfig from RPC Server", + socket=socket) - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + async def _get_scheduler_config_rpc(self, socket: Socket) -> SchedulerConfig: """Get SchedulerConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCStartupRequest.GET_SCHEDULER_CONFIG, expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig 
from RPC Server") + error_message="Could not get SchedulerConfig from RPC Server", + socket=socket) - async def _get_lora_config_rpc(self) -> LoRAConfig: + async def _get_lora_config_rpc(self, socket: Socket) -> LoRAConfig: """Get LoRAConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, + RPCStartupRequest.GET_LORA_CONFIG, expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") + error_message="Could not get LoRAConfig from RPC Server", + socket=socket) - async def _is_tracing_enabled_rpc(self) -> bool: + async def _is_tracing_enabled_rpc(self, socket: Socket) -> bool: """Get is_tracing_enabled flag from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, + RPCStartupRequest.GET_TRACING_ENABLED, expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") + error_message="Could not get is_tracing_enabled from RPC Server", + socket=socket) async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" - # Suppress timeouts as well. + # Suppress timeouts and MPClientClosedError. # In cases where the server is busy processing requests and a very # large volume of abort requests arrive, it is likely that the server # will not be able to ack all of them in time. We have seen this when @@ -260,17 +285,17 @@ async def abort(self, request_id: str): # In this case we assume that the server has received or will receive # these abort requests, and ignore the timeout. This prevents a massive # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): + with suppress(MPClientClosedError, TimeoutError): await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") + socket=self.input_socket) async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): + with suppress(MPClientClosedError): await self._send_one_way_rpc_request( request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") + socket=self.input_socket) @property def is_running(self) -> bool: @@ -340,29 +365,15 @@ async def generate( if not finished and not self._errored: await self.abort(request_id) - async def check_health(self, socket: Optional[Socket] = None) -> None: + async def check_health(self) -> None: """Raise if unhealthy""" await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) + request=RPCUtilityRequest.CHECK_HEALTH, + socket=self.input_socket) + async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: raise NotImplementedError( "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 0671c48d84c6..6323b1d0734f 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ 
b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -2,17 +2,19 @@ import zmq import cloudpickle import pickle -from typing import Any, Type, Union, Iterator +from typing import Iterator, List, Type, Union from contextlib import contextmanager -import vllm.envs as envs from vllm import AsyncEngineArgs, LLMEngine, AsyncLLMEngine from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, - RPCUtilityRequest) -from vllm.utils import print_warning_once + RPCGenerateRequest, + RPCAbortRequest, + RPCStartupRequest, + RPCUtilityRequest) +from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, @@ -64,27 +66,13 @@ def __init__(self, self.log_requests = log_requests self.engine = self._init_engine(*args, **kwargs) - if self.engine_use_ray: - print_warning_once( - "DEPRECATED. `--engine-use-ray` is deprecated and will " - "be removed in a future update. " - "See https://github.com/vllm-project/vllm/issues/7045.") - - if envs.VLLM_ALLOW_ENGINE_USE_RAY: - print_warning_once( - "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray") - else: - raise ValueError("`--engine-use-ray` is deprecated. " - "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to " - "force use it") - self.ctx = zmq.Context() - # Recieve RPCGenerateRequest from the client. + # Recieve input from the client. self.input_socket = self.ctx.socket(zmq.constants.PULL) self.input_socket.bind(f"{ipc_path}_input_socket") - # Send streams of RequestOutput back to Client. + # Send output stream back to client. self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}_output_socket") @@ -144,6 +132,7 @@ def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: self._engine_class).remote return engine_class(*args, **kwargs) + def run_background_loop(self): """Entrypoint that kicks off the background processing loop.""" @@ -152,7 +141,8 @@ def run_background_loop(self): # Kick off core processing loop. self.run_engine_loop() - + + @contextmanager def make_data_socket(self) -> Iterator[zmq.Socket]: socket = self.ctx.socket(zmq.constants.ROUTER) @@ -163,7 +153,7 @@ def make_data_socket(self) -> Iterator[zmq.Socket]: socket.close(linger=0) def run_startup_loop(self) -> None: - """Loop over startup RPCRequests from RPCClient.""" + """Loop over startup RPCStatupRequest from RPCClient.""" with self.make_data_socket() as socket: @@ -172,29 +162,27 @@ def run_startup_loop(self) -> None: while not client_is_ready: try: identity, message = socket.recv_multipart(copy=False) - request: RPCUtilityRequest = cloudpickle.loads(message.buffer) + request: RPCStartupRequest = pickle.loads(message.buffer) # Handle the query from the Client. 
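# Illustrative sketch, not part of this patch: one possible alternative shape
# for this dispatch, mapping each startup request to a zero-argument getter so
# the handler body is a single lookup. The FakeEngine and its return values are
# stand-ins, not vLLM objects.
from enum import Enum
from typing import Any, Callable, Dict


class StartupRequest(Enum):
    GET_MODEL_CONFIG = 1
    GET_DECODING_CONFIG = 2
    IS_SERVER_READY = 3


class FakeEngine:
    def get_model_config(self) -> dict:
        return {"model": "dummy-model"}

    def get_decoding_config(self) -> dict:
        return {"backend": "dummy"}


engine = FakeEngine()
handlers: Dict[StartupRequest, Callable[[], Any]] = {
    StartupRequest.GET_MODEL_CONFIG: engine.get_model_config,
    StartupRequest.GET_DECODING_CONFIG: engine.get_decoding_config,
    StartupRequest.IS_SERVER_READY: lambda: "SUCCESS",
}


def handle(request: StartupRequest) -> Any:
    return handlers[request]()           # KeyError doubles as "unknown request"


assert handle(StartupRequest.IS_SERVER_READY) == "SUCCESS"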
- if request == RPCUtilityRequest.GET_MODEL_CONFIG: + if request == RPCStartupRequest.GET_MODEL_CONFIG: response = self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + elif request == RPCStartupRequest.GET_DECODING_CONFIG: response = self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: + elif request == RPCStartupRequest.GET_LORA_CONFIG: response = self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + elif request == RPCStartupRequest.GET_SCHEDULER_CONFIG: response = self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + elif request == RPCStartupRequest.GET_PARALLEL_CONFIG: response = self.engine.get_parallel_config() - elif request == RPCUtilityRequest.IS_SERVER_READY: - response = VLLM_RPC_SUCCESS_STR - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + elif request == RPCStartupRequest.GET_TRACING_ENABLED: response = self.engine.is_tracing_enabled() - elif request == RPCUtilityRequest.CLIENT_IS_READY: + elif request == RPCStartupRequest.IS_SERVER_READY: + response = VLLM_RPC_SUCCESS_STR + elif request == RPCStartupRequest.CLIENT_IS_READY: response = VLLM_RPC_SUCCESS_STR - # Once client ready, breakout of loop. + # Breakout of loop once client is ready. client_is_ready = True - else: - raise ValueError(f"Unknown RPCRequest: {request}") socket.send_multipart( (identity, pickle.dumps(response)), copy=False) @@ -203,43 +191,61 @@ def run_startup_loop(self) -> None: socket.send_multipart((identity, pickle.dumps(e)), copy=False) def run_engine_loop(self) -> None: - # TODO: handle PP - while True: # Block until there is a new request. if not self.engine.has_unfinished_requests(): - self.wait_for_new_requests() + self.wait_for_new_input() - # Add new work from input socket. - self.maybe_add_new_requests() + # Handle any new input from the input socket. + self.maybe_handle_new_input() # Engine step. request_outputs = self.engine.step() # Stream results to output socket. 
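# Illustrative sketch, not part of this patch: the loop shape used here, with a
# FakeEngine and a plain queue standing in for the LLMEngine and the input
# socket. Block only while idle, drain queued input without blocking, run one
# step, then stream whatever finished.
import queue


class FakeEngine:
    def __init__(self):
        self.pending = []

    def has_unfinished_requests(self) -> bool:
        return bool(self.pending)

    def add_request(self, request_id: str) -> None:
        self.pending.append(request_id)

    def step(self):
        return [self.pending.pop(0)] if self.pending else []   # one finishes per step


def engine_loop(inputs: "queue.Queue[str]", engine: FakeEngine, num_steps: int) -> None:
    for _ in range(num_steps):
        if not engine.has_unfinished_requests():
            engine.add_request(inputs.get())          # block: idle until input arrives
        while True:                                   # drain the rest without blocking
            try:
                engine.add_request(inputs.get_nowait())
            except queue.Empty:
                break
        for output in engine.step():                  # step, then "stream" results
            print("finished:", output)


inputs: "queue.Queue[str]" = queue.Queue()
for rid in ("req-0", "req-1", "req-2"):
    inputs.put(rid)
engine_loop(inputs, FakeEngine(), num_steps=3)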
- self.stream_outputs(request_outputs) - + self.stream_outputs(request_outputs) - def wait_for_new_requests(self): + def wait_for_new_input(self): while self.input_socket.poll(timeout=10000) == 0: logger.debug("Waiting for new request.") - def stream_outputs(self, request_outputs): + def stream_outputs(self, request_outputs: List[RequestOutput]): self.output_socket.send_multipart( (pickle.dumps(request_outputs),), copy=False) - - def maybe_add_new_requests(self): + + def maybe_handle_new_input(self): + """Handle new input with non-blocking IO""" while self.input_socket.poll(timeout=0) != 0: message = self.input_socket.recv(copy=False) - generate_rpc_request = pickle.loads(message.buffer) - self.engine.add_request( - request_id=generate_rpc_request.request_id, - inputs=generate_rpc_request.inputs, - params=generate_rpc_request.sampling_params, - lora_request=generate_rpc_request.lora_request, - trace_headers=generate_rpc_request.trace_headers, - prompt_adapter_request=generate_rpc_request.prompt_adapter_request, - ) + request = cloudpickle.loads(message.buffer) + + if isinstance(request, RPCGenerateRequest): + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUtilityRequest): + self._handle_utility_request(request) + else: + raise ValueError(f"Unknown RPCRequest: {request}") + + def _handle_generate_request(self, request: RPCGenerateRequest): + self.engine.add_request( + request_id=request.request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + ) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request([request.request_id]) + + def _handle_utility_request(self, request: RPCUtilityRequest): + if request == RPCUtilityRequest.DO_LOG_STATS: + self.engine.do_log_stats() + elif request == RPCUtilityRequest.CHECK_HEALTH: + self.engine.check_health() def run_mp_engine(engine_args: AsyncEngineArgs, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6dcbbd433596..ef8f98a3889b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -82,11 +82,10 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(10) + await asyncio.sleep(1.) await async_engine_client.do_log_stats() - # if not engine_args.disable_log_stats: - if False: + if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) From 9886f3dc689e52c62ceea87723af3858b710f5f3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 01:54:40 +0000 Subject: [PATCH 020/116] make health check work --- vllm/engine/multiprocessing/mp_client.py | 20 ++++++++++++++++++-- vllm/engine/multiprocessing/mp_llm_engine.py | 13 +++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index 086242d28fb5..eff3b0d06e40 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -55,6 +55,10 @@ def __init__(self, ipc_path: str): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}_output_socket") + # IPC path for awk of check_health requests. 
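# Illustrative sketch, not part of this patch: the round trip this health socket
# enables. The check travels over the one-way input pipe, and the ack comes back
# on a dedicated health pipe, bounded by a poll() timeout on the client side.
import asyncio
import pickle

import zmq
import zmq.asyncio


async def main():
    ctx = zmq.asyncio.Context()
    # Engine side: pull requests in, push health acks out.
    req_in = ctx.socket(zmq.PULL)
    req_in.bind("inproc://input")
    health_out = ctx.socket(zmq.PUSH)
    health_out.bind("inproc://health")
    # Client side: push requests, pull health acks.
    req_out = ctx.socket(zmq.PUSH)
    req_out.connect("inproc://input")
    health_in = ctx.socket(zmq.PULL)
    health_in.connect("inproc://health")

    await req_out.send(pickle.dumps("CHECK_HEALTH"))            # client

    assert pickle.loads(await req_in.recv()) == "CHECK_HEALTH"  # engine
    await health_out.send(pickle.dumps("SUCCESS"))              # ack on health pipe

    if await health_in.poll(timeout=1000) == 0:                 # client: bounded wait
        raise TimeoutError("engine did not acknowledge the health check")
    print("health ack:", pickle.loads(await health_in.recv()))

    ctx.destroy(linger=0)


asyncio.run(main())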
+ self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}_health_socket") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}_data_socket" @@ -164,7 +168,8 @@ async def _send_one_way_rpc_request(self, await socket.send_multipart((cloudpickle.dumps(request), )) # TODO: is there a way to ack this if we are using the input_socket? - # I don't think so, b/c we are using PUSH/PULL + # I don't think so, b/c we are using PUSH/PULL w/out identities so no + # way to preserve order. async def _awk_one_way_rpc_request(self, timeout: int, @@ -349,7 +354,7 @@ async def generate( # possibly setting the `errored` property. if not self._errored: try: - # await self.check_health(socket=socket) + await self.check_health() pass except Exception as e: self._errored = True @@ -371,6 +376,17 @@ async def check_health(self) -> None: await self._send_one_way_rpc_request( request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) + + # Await awknoledgement from MPLLMEngine. + # Note: these requests are not necessarily serial. + # I.e. if two clients A, B send CHECK_HEALTH, the + # response to A could actually be the call send by B. + # TODO: is this bad? + await self._awk_one_way_rpc_request( + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + expected_str=VLLM_RPC_SUCCESS_STR, + error_message="Check health timeout.", + socket=self.health_socket) async def encode(self, *args, diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 6323b1d0734f..8ac1ade81316 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -76,6 +76,10 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}_output_socket") + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}_health_socket") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}_data_socket" @@ -213,6 +217,10 @@ def stream_outputs(self, request_outputs: List[RequestOutput]): self.output_socket.send_multipart( (pickle.dumps(request_outputs),), copy=False) + def awk_check_health(self): + self.health_socket.send_multipart( + (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) + def maybe_handle_new_input(self): """Handle new input with non-blocking IO""" while self.input_socket.poll(timeout=0) != 0: @@ -246,8 +254,9 @@ def _handle_utility_request(self, request: RPCUtilityRequest): self.engine.do_log_stats() elif request == RPCUtilityRequest.CHECK_HEALTH: self.engine.check_health() - - + # Special check_health channel for awk check health. 
+ self.awk_check_health() + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): From 5b2f0577fdbe5bf0f86e50297f2c57254f95f7c2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:46:43 +0000 Subject: [PATCH 021/116] format --- benchmarks/benchmark_throughput_async.py | 22 ++--- vllm/engine/multiprocessing/__init__.py | 2 +- vllm/engine/multiprocessing/mp_client.py | 76 ++++++++---------- vllm/engine/multiprocessing/mp_llm_engine.py | 84 ++++++++++---------- vllm/entrypoints/openai/api_server.py | 23 +++--- 5 files changed, 100 insertions(+), 107 deletions(-) diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py index 54eed0f4de78..217f11d14d30 100644 --- a/benchmarks/benchmark_throughput_async.py +++ b/benchmarks/benchmark_throughput_async.py @@ -1,6 +1,5 @@ """Benchmark offline inference throughput.""" import argparse -import asyncio import json import random import time @@ -12,11 +11,11 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.entrypoints.openai.api_server import build_async_engine_client_from_engine_args -from vllm.utils import merge_async_iterators -from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, merge_async_iterators def sample_requests( @@ -92,7 +91,7 @@ async def run_vllm( load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, ) -> float: - from vllm import LLM, SamplingParams + from vllm import SamplingParams engine_args = AsyncEngineArgs( model=model, tokenizer=tokenizer, @@ -123,8 +122,8 @@ async def run_vllm( decoupled = True - async with build_async_engine_client_from_engine_args(engine_args, - not decoupled) as llm: + async with build_async_engine_client_from_engine_args( + engine_args, not decoupled) as llm: # Add the requests to the engine. 
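# Illustrative sketch, not part of this patch: the fan-out pattern this benchmark
# relies on. Each request gets its own async generator, and a small merger
# interleaves their outputs as (index, item) pairs, similar in spirit to vLLM's
# merge_async_iterators helper; fake_generate stands in for llm.generate(...).
import asyncio
from typing import AsyncIterator, Tuple


async def fake_generate(prompt: str, n_tokens: int) -> AsyncIterator[str]:
    for t in range(n_tokens):
        await asyncio.sleep(0)
        yield f"{prompt}:token{t}"


async def merge(*iterators) -> AsyncIterator[Tuple[int, str]]:
    out: asyncio.Queue = asyncio.Queue()
    done = object()

    async def drain(i: int, it) -> None:
        async for item in it:
            await out.put((i, item))
        await out.put(done)

    tasks = [asyncio.create_task(drain(i, it)) for i, it in enumerate(iterators)]
    remaining = len(tasks)
    while remaining:
        item = await out.get()
        if item is done:
            remaining -= 1
        else:
            yield item


async def main():
    gens = [fake_generate(f"prompt{i}", 3) for i in range(2)]
    async for i, text in merge(*gens):
        print(i, text)


asyncio.run(main())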
prompts: List[str] = [] @@ -146,13 +145,14 @@ async def run_vllm( for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): # generator = await llm.generate(prompt, sp, request_id=f"test{i}") generator = llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) + generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: pass end = time.perf_counter() return end - start + def run_hf( requests: List[Tuple[str, int, int]], model: str, @@ -248,7 +248,7 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - coro = run_vllm( + coro = run_vllm( requests, args.model, args.tokenizer, args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, @@ -260,7 +260,7 @@ def main(args: argparse.Namespace): args.use_v2_block_manager, args.download_dir, args.load_format, args.disable_async_output_proc) - elapsed_time = uvloop.run(coro) + elapsed_time = uvloop.run(coro) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index be7d80072f96..cf566933801e 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -40,4 +40,4 @@ class RPCStartupRequest(Enum): RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] + RPCUtilityRequest, RPCStartupRequest] diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index eff3b0d06e40..ba3269c252ba 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -1,8 +1,8 @@ import asyncio import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, Optional, - Union) +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union) import cloudpickle import zmq @@ -14,10 +14,8 @@ ParallelConfig, SchedulerConfig) # yapf: disable from vllm.engine.multiprocessing import (RPC_REQUEST_TYPE, - VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, - RPCGenerateRequest, - RPCStartupRequest, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS @@ -31,6 +29,7 @@ logger = init_logger(__name__) + class MPClientClosedError(Exception): """Exception class raised when the client is used post-close. @@ -41,6 +40,7 @@ class MPClientClosedError(Exception): So, we throw this error such that we can suppress it. """ + class MPEngineClient: def __init__(self, ipc_path: str): @@ -51,7 +51,7 @@ def __init__(self, ipc_path: str): self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) self.input_socket.connect(f"{ipc_path}_input_socket") - # Recieve streams of RequestOutput from the MPLLMEngine. + # Receive streams of RequestOutput from the MPLLMEngine. self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}_output_socket") @@ -65,7 +65,7 @@ def __init__(self, ipc_path: str): # Stream for each individual request. 
self.output_queues: Dict[str, asyncio.Queue] = {} self.output_handler = asyncio.create_task(self.run_output_handler()) - + @contextmanager def get_data_socket(self) -> Iterator[Socket]: socket = self.context.socket(zmq.constants.DEALER) @@ -79,7 +79,7 @@ async def run_output_handler(self): # Stream lists of RequestOutput from MPLLMEngine. while True: message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) + request_outputs: List[RequestOutput] = pickle.loads(message.buffer) for output in request_outputs: if isinstance(output, tuple): @@ -109,7 +109,8 @@ async def setup(self): # TODO: refactor OAI server to avoid needing this info. self.tokenizer = init_tokenizer_from_configs( model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc(socket)), + scheduler_config=(await + self._get_scheduler_config_rpc(socket)), parallel_config=(await self._get_parallel_config_rpc(socket)), enable_lora=bool(await self._get_lora_config_rpc(socket)), ) @@ -123,22 +124,19 @@ def close(self): # then terminate the context. self.context.destroy(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + async def _send_get_data_rpc_request(self, request: RPCStartupRequest, expected_type: Any, error_message: str, socket: Socket) -> Any: """Send an RPC request that is expecting data back.""" # Ping RPCServer with a request. - await socket.send_multipart( - (cloudpickle.dumps(request), ), - copy=False) + await socket.send_multipart((cloudpickle.dumps(request), ), copy=False) # Make sure the server responds if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") # Await the data from the Server. frame = await socket.recv(copy=False) @@ -160,8 +158,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, return data - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, socket: Socket): """Send one-way RPC request to trigger an action.""" @@ -170,16 +167,17 @@ async def _send_one_way_rpc_request(self, # TODO: is there a way to ack this if we are using the input_socket? # I don't think so, b/c we are using PUSH/PULL w/out identities so no # way to preserve order. - - async def _awk_one_way_rpc_request(self, - timeout: int, - expected_str: str, - error_message: str, - socket: Socket,): + + async def _awk_one_way_rpc_request( + self, + timeout: int, + expected_str: str, + error_message: str, + socket: Socket, + ): if await socket.poll(timeout=timeout) == 0: raise TimeoutError(f"MPLLMEngine didn't reply within {timeout}ms") - - + frame = await socket.recv(copy=False) response = pickle.loads(frame.buffer) @@ -203,7 +201,7 @@ async def is_tracing_enabled(self) -> bool: async def _wait_for_server_rpc(self, socket: Socket): """Wait for the RPCServer to start up.""" - + # Readiness probe. 
request = RPCStartupRequest.IS_SERVER_READY await socket.send_multipart((cloudpickle.dumps(request), )) @@ -214,14 +212,12 @@ async def _wait_for_server_rpc(self, socket: Socket): timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, error_message="Unable to start RPC Server", socket=socket) - async def _notify_ready(self, socket: Socket): """Get the RPCServer that the RPCClient is ready""" await self._send_one_way_rpc_request( - request=RPCStartupRequest.CLIENT_IS_READY, - socket=socket) + request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -250,7 +246,8 @@ async def _get_parallel_config_rpc(self, socket: Socket) -> ParallelConfig: error_message="Could not get ParallelConfig from RPC Server", socket=socket) - async def _get_scheduler_config_rpc(self, socket: Socket) -> SchedulerConfig: + async def _get_scheduler_config_rpc(self, + socket: Socket) -> SchedulerConfig: """Get SchedulerConfig from the RPCServer""" return await self._send_get_data_rpc_request( @@ -292,8 +289,7 @@ async def abort(self, request_id: str): # wall of `TimeoutError` stack traces. with suppress(MPClientClosedError, TimeoutError): await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - socket=self.input_socket) + request=RPCAbortRequest(request_id), socket=self.input_socket) async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" @@ -330,7 +326,7 @@ async def generate( self.output_queues[request_id] = queue finished = False try: - + # Send RPCGenerateRequest to the RPCServer. await self.input_socket.send_multipart((cloudpickle.dumps( RPCGenerateRequest( @@ -341,10 +337,10 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) - # ack: Frame = await socket.recv(copy=False) - # if len(ack.buffer) != 0: - # exception = pickle.loads(ack.buffer) - # raise exception + # ack: Frame = await socket.recv(copy=False) + # if len(ack.buffer) != 0: + # exception = pickle.loads(ack.buffer) + # raise exception while not finished: request_output = await queue.get() @@ -374,8 +370,7 @@ async def check_health(self) -> None: """Raise if unhealthy""" await self._send_one_way_rpc_request( - request=RPCUtilityRequest.CHECK_HEALTH, - socket=self.input_socket) + request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) # Await awknoledgement from MPLLMEngine. # Note: these requests are not necessarily serial. 
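The check_health handshake above relies on a dedicated PULL socket because the PUSH/PULL command channel carries no identities, so there is no built-in way to route a reply back to the caller. For reference, a minimal, self-contained sketch of that pattern (assuming only pyzmq's asyncio support; the TCP addresses, SUCCESS string, and function names here are illustrative stand-ins, not the engine's actual endpoints or API):

import asyncio
import pickle

import zmq
import zmq.asyncio

SUCCESS = "SUCCESS"
INPUT_ADDR = "tcp://127.0.0.1:5557"    # stands in for the _input_socket ipc path
HEALTH_ADDR = "tcp://127.0.0.1:5558"   # stands in for the _health_socket ipc path
ACK_TIMEOUT_MS = 5000


async def engine_side() -> None:
    ctx = zmq.asyncio.Context()
    input_socket = ctx.socket(zmq.PULL)   # commands arrive here; no reply path
    input_socket.bind(INPUT_ADDR)
    health_socket = ctx.socket(zmq.PUSH)  # dedicated channel for health acks
    health_socket.bind(HEALTH_ADDR)

    request = pickle.loads(await input_socket.recv())
    if request == "CHECK_HEALTH":
        await health_socket.send(pickle.dumps(SUCCESS))


async def client_side() -> None:
    ctx = zmq.asyncio.Context()
    input_socket = ctx.socket(zmq.PUSH)
    input_socket.connect(INPUT_ADDR)
    health_socket = ctx.socket(zmq.PULL)
    health_socket.connect(HEALTH_ADDR)

    await input_socket.send(pickle.dumps("CHECK_HEALTH"))
    # Wait with a timeout for the ack, analogous to the poll above.
    if await health_socket.poll(timeout=ACK_TIMEOUT_MS) == 0:
        raise TimeoutError("no health ack received")
    assert pickle.loads(await health_socket.recv()) == SUCCESS
    print("engine reported healthy")


async def main() -> None:
    await asyncio.gather(engine_side(), client_side())


asyncio.run(main())

Because the ack socket is shared, any outstanding health check can consume a given ack, which is why the comment above notes the responses are not necessarily serial.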
@@ -387,7 +382,6 @@ async def check_health(self) -> None: expected_str=VLLM_RPC_SUCCESS_STR, error_message="Check health timeout.", socket=self.health_socket) - async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 8ac1ade81316..72ced337a160 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -1,19 +1,18 @@ -import ray -import zmq -import cloudpickle import pickle -from typing import Iterator, List, Type, Union from contextlib import contextmanager +from typing import Iterator, List, Type, Union -from vllm import AsyncEngineArgs, LLMEngine, AsyncLLMEngine +import cloudpickle +import ray +import zmq + +from vllm import AsyncEngineArgs, AsyncLLMEngine, LLMEngine from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.logger import init_logger -from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, - RPCGenerateRequest, - RPCAbortRequest, - RPCStartupRequest, +from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) +from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -22,15 +21,16 @@ logger = init_logger(__name__) + class MPLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. This class is used to wrap the :class:`LLMEngine` class to enable use in asynchronous manner. It runs a background loop and uses zeromq to - recieve new requests and stream outputs incrementally to another process. + receive new requests and stream outputs incrementally to another process. The :class:`LLMEngine` is kicked off when a new RPCGenerateRequest - is recieved by the input_socket. + is received by the input_socket. The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal @@ -60,15 +60,15 @@ def __init__(self, if engine_use_ray: raise NotImplementedError("Not yet supported.") - + self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray self.log_requests = log_requests self.engine = self._init_engine(*args, **kwargs) - self.ctx = zmq.Context() + self.ctx = zmq.Context() # type: ignore[attr-defined] - # Recieve input from the client. + # Receive input from the client. 
self.input_socket = self.ctx.socket(zmq.constants.PULL) self.input_socket.bind(f"{ipc_path}_input_socket") @@ -84,10 +84,10 @@ def __init__(self, self.data_ipc_path = f"{ipc_path}_data_socket" @classmethod - def from_engine_args(cls, engine_args: AsyncEngineArgs, + def from_engine_args(cls, engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): """Creates an RPCLLM engine from the engine arguments.""" - + engine_config = engine_args.create_engine_config() if engine_args.engine_use_ray: @@ -115,7 +115,8 @@ def cleanup(self): self.ctx.destroy(linger=0) del self.engine - def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: + def _init_engine(self, *args, + **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: """Initialize the LLMEngine""" if not self.engine_use_ray: @@ -135,20 +136,19 @@ def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: engine_class = ray.remote(num_gpus=num_gpus)( self._engine_class).remote return engine_class(*args, **kwargs) - def run_background_loop(self): """Entrypoint that kicks off the background processing loop.""" - - # Allow RPCClient to query data in startup phase. + + # Allow RPCClient to query data in startup phase. self.run_startup_loop() # Kick off core processing loop. self.run_engine_loop() - @contextmanager - def make_data_socket(self) -> Iterator[zmq.Socket]: + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] socket = self.ctx.socket(zmq.constants.ROUTER) try: socket.bind(self.data_ipc_path) @@ -158,7 +158,7 @@ def make_data_socket(self) -> Iterator[zmq.Socket]: def run_startup_loop(self) -> None: """Loop over startup RPCStatupRequest from RPCClient.""" - + with self.make_data_socket() as socket: # Loop until the RPCClient has all the data it needs. @@ -187,12 +187,13 @@ def run_startup_loop(self) -> None: response = VLLM_RPC_SUCCESS_STR # Breakout of loop once client is ready. client_is_ready = True - - socket.send_multipart( - (identity, pickle.dumps(response)), copy=False) + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) except Exception as e: - socket.send_multipart((identity, pickle.dumps(e)), copy=False) + socket.send_multipart((identity, pickle.dumps(e)), + copy=False) def run_engine_loop(self) -> None: while True: @@ -202,10 +203,10 @@ def run_engine_loop(self) -> None: # Handle any new input from the input socket. self.maybe_handle_new_input() - + # Engine step. request_outputs = self.engine.step() - + # Stream results to output socket. 
self.stream_outputs(request_outputs) @@ -214,9 +215,9 @@ def wait_for_new_input(self): logger.debug("Waiting for new request.") def stream_outputs(self, request_outputs: List[RequestOutput]): - self.output_socket.send_multipart( - (pickle.dumps(request_outputs),), copy=False) - + self.output_socket.send_multipart((pickle.dumps(request_outputs), ), + copy=False) + def awk_check_health(self): self.health_socket.send_multipart( (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) @@ -235,7 +236,7 @@ def maybe_handle_new_input(self): self._handle_utility_request(request) else: raise ValueError(f"Unknown RPCRequest: {request}") - + def _handle_generate_request(self, request: RPCGenerateRequest): self.engine.add_request( request_id=request.request_id, @@ -248,7 +249,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest): def _handle_abort_request(self, request: RPCAbortRequest): self.engine.abort_request([request.request_id]) - + def _handle_utility_request(self, request: RPCUtilityRequest): if request == RPCUtilityRequest.DO_LOG_STATS: self.engine.do_log_stats() @@ -256,13 +257,12 @@ def _handle_utility_request(self, request: RPCUtilityRequest): self.engine.check_health() # Special check_health channel for awk check health. self.awk_check_health() - -def run_mp_engine(engine_args: AsyncEngineArgs, - usage_context: UsageContext, + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): - engine = MPLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=usage_context, - ipc_path=ipc_path) + engine = MPLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) engine.run_background_loop() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ef8f98a3889b..b7c0cee1af8b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -21,6 +21,9 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +# yapf: enable +from vllm.engine.multiprocessing.mp_client import MPEngineClient +from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger @@ -37,9 +40,6 @@ EmbeddingResponse, ErrorResponse, TokenizeRequest, TokenizeResponse) -# yapf: enable -from vllm.engine.multiprocessing.mp_client import MPEngineClient -from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -84,7 +84,7 @@ async def _force_log(): while True: await asyncio.sleep(1.) 
await async_engine_client.do_log_stats() - + if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) _running_tasks.add(task) @@ -170,12 +170,12 @@ async def build_async_engine_client_from_engine_args( # so we need to spawn a new process context = multiprocessing.get_context("spawn") - engine_process = context.Process( - target=run_mp_engine, - args=(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path)) + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) engine_process.start() - logger.info("Started engine process with PID %d", - engine_process.pid) + logger.info("Started engine process with PID %d", engine_process.pid) try: while True: @@ -184,9 +184,8 @@ async def build_async_engine_client_from_engine_args( break except TimeoutError: if not engine_process.is_alive(): - logger.error( - "Engine process died before responding " - "to readiness probe") + logger.error("Engine process died before responding " + "to readiness probe") yield None return From ae4564c239af500cda87c72118ccd83120388f8a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:50:13 +0000 Subject: [PATCH 022/116] awk -> ack --- vllm/engine/multiprocessing/mp_client.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index ba3269c252ba..201a3c031770 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -55,7 +55,7 @@ def __init__(self, ipc_path: str): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}_output_socket") - # IPC path for awk of check_health requests. + # IPC path for ack of check_health requests. self.health_socket: Socket = self.context.socket(zmq.constants.PULL) self.health_socket.connect(f"{ipc_path}_health_socket") @@ -168,7 +168,7 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, # I don't think so, b/c we are using PUSH/PULL w/out identities so no # way to preserve order. - async def _awk_one_way_rpc_request( + async def _ack_one_way_rpc_request( self, timeout: int, expected_str: str, @@ -206,8 +206,8 @@ async def _wait_for_server_rpc(self, socket: Socket): request = RPCStartupRequest.IS_SERVER_READY await socket.send_multipart((cloudpickle.dumps(request), )) - # Raises TimeoutError if not awk, causing a retry. - await self._awk_one_way_rpc_request( + # Raises TimeoutError if not ack, causing a retry. + await self._ack_one_way_rpc_request( expected_str=VLLM_RPC_SUCCESS_STR, timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, error_message="Unable to start RPC Server", @@ -337,11 +337,6 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) - # ack: Frame = await socket.recv(copy=False) - # if len(ack.buffer) != 0: - # exception = pickle.loads(ack.buffer) - # raise exception - while not finished: request_output = await queue.get() if isinstance(request_output, BaseException): @@ -372,12 +367,12 @@ async def check_health(self) -> None: await self._send_one_way_rpc_request( request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) - # Await awknoledgement from MPLLMEngine. + # Await acknowledgement from MPLLMEngine. # Note: these requests are not necessarily serial. # I.e. if two clients A, B send CHECK_HEALTH, the # response to A could actually be the call send by B. # TODO: is this bad? 
- await self._awk_one_way_rpc_request( + await self._ack_one_way_rpc_request( timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, expected_str=VLLM_RPC_SUCCESS_STR, error_message="Check health timeout.", From f9ccecc7048f28b3261658d28879abbf5b3b1e42 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:53:08 +0000 Subject: [PATCH 023/116] add better shutdown --- vllm/engine/multiprocessing/mp_client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index 201a3c031770..0fd1afe953c1 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -122,8 +122,13 @@ def close(self): """Destroy the ZeroMQ Context.""" # Close all sockets associated with this context and # then terminate the context. + self.output_socket.close() + self.input_socket.close() + self.health_socket.close() self.context.destroy(linger=0) + # TODO: cancel the handler task. + async def _send_get_data_rpc_request(self, request: RPCStartupRequest, expected_type: Any, error_message: str, From 89b730b9c650864a4b2c3a4eda2378f616ae1eb8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:55:06 +0000 Subject: [PATCH 024/116] cleanup comment --- vllm/engine/multiprocessing/mp_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 72ced337a160..5bdcc419de2f 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -86,7 +86,7 @@ def __init__(self, @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): - """Creates an RPCLLM engine from the engine arguments.""" + """Creates an MPLLMEngine from the engine arguments.""" engine_config = engine_args.create_engine_config() From f3dc82b584f3d2db67c87d5e8fc12d498189e902 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:59:20 +0000 Subject: [PATCH 025/116] more awk --> ack --- vllm/engine/multiprocessing/mp_llm_engine.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 5bdcc419de2f..106f7a2fb805 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -124,17 +124,7 @@ def _init_engine(self, *args, elif self.worker_use_ray: engine_class = ray.remote(num_cpus=0)(self._engine_class).remote else: - # FIXME(woosuk): This is a bit hacky. Be careful when changing the - # order of the arguments. 
- cache_config = kwargs["cache_config"] - parallel_config = kwargs["parallel_config"] - if (parallel_config.tensor_parallel_size == 1 - and parallel_config.pipeline_parallel_size == 1): - num_gpus = cache_config.gpu_memory_utilization - else: - num_gpus = 1 - engine_class = ray.remote(num_gpus=num_gpus)( - self._engine_class).remote + raise NotImplementedError("Not supported yet!") return engine_class(*args, **kwargs) def run_background_loop(self): @@ -218,7 +208,7 @@ def stream_outputs(self, request_outputs: List[RequestOutput]): self.output_socket.send_multipart((pickle.dumps(request_outputs), ), copy=False) - def awk_check_health(self): + def ack_check_health(self): self.health_socket.send_multipart( (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) @@ -255,8 +245,7 @@ def _handle_utility_request(self, request: RPCUtilityRequest): self.engine.do_log_stats() elif request == RPCUtilityRequest.CHECK_HEALTH: self.engine.check_health() - # Special check_health channel for awk check health. - self.awk_check_health() + self.ack_check_health() def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, From ac97a9ebef67308dfa963279082a798971243349 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 03:00:21 +0000 Subject: [PATCH 026/116] use constant --- vllm/engine/multiprocessing/mp_llm_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 106f7a2fb805..aca11b9293c2 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -21,6 +21,7 @@ logger = init_logger(__name__) +POLLING_TIMEOUT_MS = 10000 class MPLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -201,7 +202,7 @@ def run_engine_loop(self) -> None: self.stream_outputs(request_outputs) def wait_for_new_input(self): - while self.input_socket.poll(timeout=10000) == 0: + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: logger.debug("Waiting for new request.") def stream_outputs(self, request_outputs: List[RequestOutput]): From becd7abe426e0437fde18b854a08785d7b914db0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 03:00:30 +0000 Subject: [PATCH 027/116] format --- vllm/engine/multiprocessing/mp_llm_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index aca11b9293c2..f839a272e40f 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -23,6 +23,7 @@ POLLING_TIMEOUT_MS = 10000 + class MPLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. From b7f49edd6da76a29339ee3d0954d2a583c7af642 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 03:03:06 +0000 Subject: [PATCH 028/116] remove set to None --- vllm/entrypoints/openai/api_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b7c0cee1af8b..cca08e11912f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -134,7 +134,6 @@ async def build_async_engine_client_from_engine_args( yield async_engine_client finally: async_engine_client.shutdown_background_loop() - async_engine_client = None #TODO return # Otherwise, use the multiprocessing AsyncLLMEngine. 
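The engine loop above blocks for new work by polling its input socket against the POLLING_TIMEOUT_MS constant, logging while idle. A self-contained sketch of that idle-wait pattern (plain pyzmq; the TCP endpoint and demo payload are made up for illustration and are not the engine's real ipc paths or message format):

import pickle

import zmq

POLLING_TIMEOUT_MS = 10_000  # same role as the constant introduced above


def wait_for_new_input(input_socket: zmq.Socket) -> bytes:
    # Poll in bounded slices so the process can keep logging (and, in a
    # fuller version, react to shutdown) while it waits for work.
    while input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
        print("Waiting for new request.")
    return input_socket.recv()


if __name__ == "__main__":
    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("tcp://127.0.0.1:5559")
    push = ctx.socket(zmq.PUSH)
    push.connect("tcp://127.0.0.1:5559")
    push.send(pickle.dumps({"request_id": "demo-request"}))
    print(pickle.loads(wait_for_new_input(pull)))
    ctx.destroy()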
@@ -207,8 +206,6 @@ async def build_async_engine_client_from_engine_args( from prometheus_client import multiprocess multiprocess.mark_process_dead(engine_process.pid) - async_engine_client = None #TODO - router = APIRouter() From d0f964158de7d3db00fd90796207eadcf0fac862 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 3 Sep 2024 20:50:41 -0700 Subject: [PATCH 029/116] Remove redundant pass --- vllm/engine/multiprocessing/mp_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index 0fd1afe953c1..c49a836755d4 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -351,7 +351,6 @@ async def generate( if not self._errored: try: await self.check_health() - pass except Exception as e: self._errored = True logger.exception(repr(e)) From 5c6e5ef4964c746a51e72bb00b58d385acef8c8a Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 4 Sep 2024 19:19:40 +0000 Subject: [PATCH 030/116] review comments --- examples/openai_completion_client.py | 2 +- vllm/engine/multiprocessing/mp_llm_engine.py | 1 + vllm/entrypoints/openai/api_server.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 0b77ed4d2558..282885746d8d 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -19,7 +19,7 @@ model=model, prompt="A robot may not injure a human being", stream=stream, - max_tokens=100) + logprobs=3) print("Completion results:") if stream: diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index f839a272e40f..5e19fe1a787c 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -114,6 +114,7 @@ def cleanup(self): """Cleanup zeromq state on shutdown.""" self.input_socket.close() self.output_socket.close() + self.health_socket.close() self.ctx.destroy(linger=0) del self.engine diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2e09954800d0..c6270d04c809 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -21,7 +21,6 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -# yapf: enable from vllm.engine.multiprocessing.mp_client import MPEngineClient from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.engine.protocol import AsyncEngineClient @@ -40,6 +39,7 @@ EmbeddingResponse, ErrorResponse, TokenizeRequest, TokenizeResponse) +# yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -82,7 +82,7 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(1.) 
+ await asyncio.sleep(10) await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: From 25174a5288bb382d5d2175695be6ea80a73a1295 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 4 Sep 2024 19:22:22 +0000 Subject: [PATCH 031/116] format --- vllm/engine/multiprocessing/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index cf566933801e..b50700c3f7c4 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -10,6 +10,7 @@ # Success string used for RPC instructions. VLLM_RPC_SUCCESS_STR = "SUCCESS" + @dataclass class RPCGenerateRequest: inputs: PromptInputs @@ -24,10 +25,12 @@ class RPCGenerateRequest: class RPCAbortRequest: request_id: str + class RPCUtilityRequest(Enum): DO_LOG_STATS = 1 CHECK_HEALTH = 2 + class RPCStartupRequest(Enum): IS_SERVER_READY = 1 GET_MODEL_CONFIG = 2 From db55c1abec31cb0d7d2a4b039a77773489bf4000 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 4 Sep 2024 19:44:43 +0000 Subject: [PATCH 032/116] add async socket reads and socket writes --- vllm/engine/multiprocessing/mp_llm_engine.py | 21 +++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 5e19fe1a787c..f5b9eb409126 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -57,6 +57,7 @@ def __init__(self, engine_use_ray: bool, *args, ipc_path: str, + use_async_sockets: bool, log_requests: bool = True, **kwargs) -> None: @@ -85,6 +86,10 @@ def __init__(self, # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}_data_socket" + # Indicates if we do socket read/write async with + # the GPU forward pass + self.use_async_sockets = use_async_sockets + @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): @@ -108,7 +113,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, ipc_path=ipc_path, - ) + use_async_sockets=engine_config.model_config.use_async_output_proc) def cleanup(self): """Cleanup zeromq state on shutdown.""" @@ -189,6 +194,10 @@ def run_startup_loop(self) -> None: copy=False) def run_engine_loop(self) -> None: + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self.stream_outputs_and_get_inputs + while True: # Block until there is a new request. if not self.engine.has_unfinished_requests(): @@ -200,8 +209,9 @@ def run_engine_loop(self) -> None: # Engine step. request_outputs = self.engine.step() - # Stream results to output socket. - self.stream_outputs(request_outputs) + if not self.use_async_sockets: + # Stream results to output socket. 
+ self.stream_outputs(request_outputs) def wait_for_new_input(self): while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: @@ -211,6 +221,11 @@ def stream_outputs(self, request_outputs: List[RequestOutput]): self.output_socket.send_multipart((pickle.dumps(request_outputs), ), copy=False) + def stream_outputs_and_get_inputs(self, + request_outputs: List[RequestOutput]): + self.stream_outputs(request_outputs) + self.maybe_handle_new_input() + def ack_check_health(self): self.health_socket.send_multipart( (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) From f97e1f268aa2967b728be3f2d8018536ad723c23 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 4 Sep 2024 14:33:29 -0700 Subject: [PATCH 033/116] Some error handling --- vllm/engine/multiprocessing/mp_client.py | 11 +++-- vllm/engine/multiprocessing/mp_llm_engine.py | 48 ++++++++++++-------- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index c49a836755d4..9f19c9d22095 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -88,9 +88,14 @@ async def run_output_handler(self): else: request_id = output.request_id - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(output) + if request_id is not None: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) + else: + # request_id None means apply to all active requests. + for queue in tuple(self.output_queues.values()): + queue.put_nowait(output) async def setup(self): """Setup the client before it starts sending server requests.""" diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index f5b9eb409126..c0bc8d0d5531 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -1,6 +1,6 @@ import pickle from contextlib import contextmanager -from typing import Iterator, List, Type, Union +from typing import Iterator, List, Optional, Tuple, Type, Union import cloudpickle import ray @@ -23,6 +23,8 @@ POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + class MPLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -207,28 +209,33 @@ def run_engine_loop(self) -> None: self.maybe_handle_new_input() # Engine step. - request_outputs = self.engine.step() + try: + request_outputs = self.engine.step() + except Exception as e: + # Fail all requests with this exception. + request_outputs = (None, e) if not self.use_async_sockets: # Stream results to output socket. 
- self.stream_outputs(request_outputs) + self.send_outputs(request_outputs) def wait_for_new_input(self): while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: logger.debug("Waiting for new request.") - def stream_outputs(self, request_outputs: List[RequestOutput]): - self.output_socket.send_multipart((pickle.dumps(request_outputs), ), - copy=False) + def send_outputs(self, request_outputs: Union[List[RequestOutput], + Tuple[Optional[str], + BaseException]]): + output_bytes = pickle.dumps(request_outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) def stream_outputs_and_get_inputs(self, request_outputs: List[RequestOutput]): - self.stream_outputs(request_outputs) + self.send_outputs(request_outputs) self.maybe_handle_new_input() def ack_check_health(self): - self.health_socket.send_multipart( - (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) def maybe_handle_new_input(self): """Handle new input with non-blocking IO""" @@ -246,17 +253,22 @@ def maybe_handle_new_input(self): raise ValueError(f"Unknown RPCRequest: {request}") def _handle_generate_request(self, request: RPCGenerateRequest): - self.engine.add_request( - request_id=request.request_id, - inputs=request.inputs, - params=request.sampling_params, - lora_request=request.lora_request, - trace_headers=request.trace_headers, - prompt_adapter_request=request.prompt_adapter_request, - ) + request_id = request.request_id + try: + self.engine.add_request( + request_id=request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + ) + except Exception as e: + self.engine.abort_request(request_id) + self.send_outputs((request_id, e)) def _handle_abort_request(self, request: RPCAbortRequest): - self.engine.abort_request([request.request_id]) + self.engine.abort_request(request.request_id) def _handle_utility_request(self, request: RPCUtilityRequest): if request == RPCUtilityRequest.DO_LOG_STATS: From dd96d3ec1e7145338ddf89d22aef938270d62361 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 7 Sep 2024 13:11:15 +0000 Subject: [PATCH 034/116] remove async benchmark --- benchmarks/benchmark_throughput_async.py | 480 ----------------------- 1 file changed, 480 deletions(-) delete mode 100644 benchmarks/benchmark_throughput_async.py diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py deleted file mode 100644 index 217f11d14d30..000000000000 --- a/benchmarks/benchmark_throughput_async.py +++ /dev/null @@ -1,480 +0,0 @@ -"""Benchmark offline inference throughput.""" -import argparse -import json -import random -import time -from typing import List, Optional, Tuple - -import torch -import uvloop -from tqdm import tqdm -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) - -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser, merge_async_iterators - - -def sample_requests( - dataset_path: str, - num_requests: int, - tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: - if fixed_output_len is not None and fixed_output_len < 4: - raise 
ValueError("output_len too small") - - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] - - # Shuffle the dataset. - random.shuffle(dataset) - - # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] - for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: - break - - # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids - completion = dataset[i][1] - completion_token_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - return filtered_dataset - - -async def run_vllm( - requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, - n: int, - use_beam_search: bool, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - distributed_executor_backend: Optional[str], - gpu_memory_utilization: float = 0.9, - num_scheduler_steps: int = 1, - use_v2_block_manager: bool = False, - download_dir: Optional[str] = None, - load_format: str = EngineArgs.load_format, - disable_async_output_proc: bool = False, -) -> float: - from vllm import SamplingParams - engine_args = AsyncEngineArgs( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - load_format=load_format, - num_scheduler_steps=num_scheduler_steps, - use_v2_block_manager=use_v2_block_manager, - disable_async_output_proc=disable_async_output_proc, - worker_use_ray=False, - engine_use_ray=False, - disable_log_requests=True, - ) - - decoupled = True - - async with build_async_engine_client_from_engine_args( - engine_args, not decoupled) as llm: - - # Add the requests to the engine. 
- prompts: List[str] = [] - sampling_params: List[SamplingParams] = [] - for prompt, _, output_len in requests: - prompts.append(prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=0.0 if use_beam_search else 1.0, - top_p=1.0, - use_beam_search=use_beam_search, - ignore_eos=True, - max_tokens=output_len, - )) - - generators = [] - start = time.perf_counter() - for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): - # generator = await llm.generate(prompt, sp, request_id=f"test{i}") - generator = llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) - all_gens = merge_async_iterators(*generators) - async for i, res in all_gens: - pass - end = time.perf_counter() - return end - start - - -def run_hf( - requests: List[Tuple[str, int, int]], - model: str, - tokenizer: PreTrainedTokenizerBase, - n: int, - use_beam_search: bool, - max_batch_size: int, - trust_remote_code: bool, -) -> float: - assert not use_beam_search - llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) - if llm.config.model_type == "llama": - # To enable padding in the HF backend. - tokenizer.pad_token = tokenizer.eos_token - llm = llm.cuda() - - pbar = tqdm(total=len(requests)) - start = time.perf_counter() - batch: List[str] = [] - max_prompt_len = 0 - max_output_len = 0 - for i in range(len(requests)): - prompt, prompt_len, output_len = requests[i] - # Add the prompt to the batch. - batch.append(prompt) - max_prompt_len = max(max_prompt_len, prompt_len) - max_output_len = max(max_output_len, output_len) - if len(batch) < max_batch_size and i != len(requests) - 1: - # Check if we can add more requests to the batch. - _, next_prompt_len, next_output_len = requests[i + 1] - if (max(max_prompt_len, next_prompt_len) + - max(max_output_len, next_output_len)) <= 2048: - # We can add more requests to the batch. - continue - - # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", - padding=True).input_ids - llm_outputs = llm.generate( - input_ids=input_ids.cuda(), - do_sample=not use_beam_search, - num_return_sequences=n, - temperature=1.0, - top_p=1.0, - use_cache=True, - max_new_tokens=max_output_len, - ) - # Include the decoding time. - tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) - pbar.update(len(batch)) - - # Clear the batch. - batch = [] - max_prompt_len = 0 - max_output_len = 0 - end = time.perf_counter() - return end - start - - -def run_mii( - requests: List[Tuple[str, int, int]], - model: str, - tensor_parallel_size: int, - output_len: int, -) -> float: - from mii import client, serve - llm = serve(model, tensor_parallel=tensor_parallel_size) - prompts = [prompt for prompt, _, _ in requests] - - start = time.perf_counter() - llm.generate(prompts, max_new_tokens=output_len) - end = time.perf_counter() - client = client(model) - client.terminate_server() - return end - start - - -def main(args: argparse.Namespace): - print(args) - random.seed(args.seed) - - # Sample the requests. - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) - if args.dataset is None: - # Synthesize a prompt with the given input length. 
- prompt = "hi" * (args.input_len - 1) - requests = [(prompt, args.input_len, args.output_len) - for _ in range(args.num_prompts)] - else: - requests = sample_requests(args.dataset, args.num_prompts, tokenizer, - args.output_len) - - if args.backend == "vllm": - coro = run_vllm( - requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.distributed_executor_backend, - args.gpu_memory_utilization, args.num_scheduler_steps, - args.use_v2_block_manager, args.download_dir, args.load_format, - args.disable_async_output_proc) - - elapsed_time = uvloop.run(coro) - elif args.backend == "hf": - assert args.tensor_parallel_size == 1 - elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.use_beam_search, args.hf_max_batch_size, - args.trust_remote_code) - elif args.backend == "mii": - elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, - args.output_len) - else: - raise ValueError(f"Unknown backend: {args.backend}") - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len in requests) - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} tokens/s") - - # Output JSON results if specified - if args.output_json: - results = { - "elapsed_time": elapsed_time, - "num_requests": len(requests), - "total_num_tokens": total_num_tokens, - "requests_per_second": len(requests) / elapsed_time, - "tokens_per_second": total_num_tokens / elapsed_time, - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - - -if __name__ == "__main__": - parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii"], - default="vllm") - parser.add_argument("--dataset", - type=str, - default=None, - help="Path to the dataset.") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.") - parser.add_argument("--model", type=str, default="facebook/opt-125m") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument("--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.") - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). 
' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument("--enforce-eager", - action="store_true", - help="enforce eager execution") - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. ' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') - parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, OpenVINO and ' - 'CPU.') - parser.add_argument( - "--num-scheduler-steps", - type=int, - default=1, - help="Maximum number of forward steps per scheduler call.") - parser.add_argument("--use-v2-block-manager", - action='store_true', - help="Enable block manager v2.") - parser.add_argument( - "--enable-prefix-caching", - action='store_true', - help="Enable automatic prefix caching for vLLM backend.") - parser.add_argument("--enable-chunked-prefill", - action='store_true', - help="enable chunked prefill for vLLM backend.") - parser.add_argument('--max-num-batched-tokens', - type=int, - default=None, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') - parser.add_argument( - '--output-json', - type=str, - default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], - default=None, - help='Backend to use for distributed serving. 
When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') - parser.add_argument( - "--disable-async-output-proc", - action='store_true', - default=False, - help="Disable async output processor for vLLM backend.") - args = parser.parse_args() - if args.tokenizer is None: - args.tokenizer = args.model - if args.dataset is None: - assert args.input_len is not None - assert args.output_len is not None - else: - assert args.input_len is None - - if args.backend == "vllm": - if args.hf_max_batch_size is not None: - raise ValueError("HF max batch size is only for HF backend.") - elif args.backend == "hf": - if args.hf_max_batch_size is None: - raise ValueError("HF max batch size is required for HF backend.") - if args.quantization is not None: - raise ValueError("Quantization is only for vLLM backend.") - elif args.backend == "mii": - if args.dtype != "auto": - raise ValueError("dtype must be auto for MII backend.") - if args.n != 1: - raise ValueError("n must be 1 for MII backend.") - if args.use_beam_search: - raise ValueError("Beam search is not supported for MII backend.") - if args.quantization is not None: - raise ValueError("Quantization is only for vLLM backend.") - if args.hf_max_batch_size is not None: - raise ValueError("HF max batch size is only for HF backend.") - if args.tokenizer != args.model: - raise ValueError("Tokenizer must be the same as the model for MII " - "backend.") - main(args) From 14d4afe66a9acbfd06086c03c0665694ffeb0dbe Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 7 Sep 2024 13:26:30 +0000 Subject: [PATCH 035/116] stash --- examples/openai_completion_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 282885746d8d..32d58ec5317a 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -2,7 +2,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. 
openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" +openai_api_base = "http://localhost:8001/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") From c0d0d60bcb3a3f15cdba9037901e97e4495a41f5 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 7 Sep 2024 20:33:38 +0000 Subject: [PATCH 036/116] adding error handling --- vllm/engine/async_llm_engine.py | 70 +---- vllm/engine/multiprocessing/__init__.py | 8 + vllm/engine/multiprocessing/mp_client.py | 96 +++---- vllm/engine/multiprocessing/mp_llm_engine.py | 253 +++++++++++-------- vllm/executor/executor_base.py | 71 +++++- 5 files changed, 272 insertions(+), 226 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 17b9ed40e41c..3b6ab0f03c16 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,7 @@ from typing_extensions import assert_never import vllm.envs as envs -from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs @@ -15,8 +15,8 @@ from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, PromptComponents, SchedulerOutputState) from vllm.engine.metrics_types import StatLoggerBase -from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.ray_utils import initialize_ray_cluster, ray +from vllm.executor.executor_base import get_executor_cls +from vllm.executor.ray_utils import ray from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, SingletonPromptInputs) from vllm.inputs.parse import is_explicit_encoder_decoder_prompt @@ -650,68 +650,6 @@ def __init__(self, # Lazy initialized fields self._request_tracker: RequestTracker - @classmethod - def _get_executor_cls( - cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend) - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorAsyncBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorAsyncBase. 
Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) - executor_class = distributed_executor_backend - elif engine_config.device_config.device_type == "neuron": - from vllm.executor.neuron_executor import NeuronExecutorAsync - executor_class = NeuronExecutorAsync - elif engine_config.device_config.device_type == "tpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync - executor_class = RayTPUExecutorAsync - else: - assert distributed_executor_backend is None - from vllm.executor.tpu_executor import TPUExecutorAsync - executor_class = TPUExecutorAsync - elif engine_config.device_config.device_type == "cpu": - from vllm.executor.cpu_executor import CPUExecutorAsync - executor_class = CPUExecutorAsync - elif engine_config.device_config.device_type == "openvino": - assert distributed_executor_backend is None, ( - "Distributed execution is not supported with " - "the OpenVINO backend.") - from vllm.executor.openvino_executor import OpenVINOExecutorAsync - executor_class = OpenVINOExecutorAsync - elif engine_config.device_config.device_type == "xpu": - if distributed_executor_backend is None: - from vllm.executor.xpu_executor import XPUExecutorAsync - executor_class = XPUExecutorAsync - elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync - executor_class = RayXPUExecutorAsync - elif distributed_executor_backend == "mp": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.multiproc_xpu_executor import ( - MultiprocessingXPUExecutorAsync) - executor_class = MultiprocessingXPUExecutorAsync - else: - raise RuntimeError( - "Not supported distributed execution model on XPU device.") - elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync - executor_class = RayGPUExecutorAsync - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutorAsync) - executor_class = MultiprocessingGPUExecutorAsync - else: - from vllm.executor.gpu_executor import GPUExecutorAsync - executor_class = GPUExecutorAsync - return executor_class @classmethod def from_engine_args( @@ -729,7 +667,7 @@ def from_engine_args( from vllm.executor import ray_utils ray_utils.assert_ray_available() - executor_class = cls._get_executor_cls(engine_config) + executor_class = get_executor_cls(engine_config) # Create the async LLM engine. engine = cls( diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index b50700c3f7c4..26c4d42d4232 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -9,6 +9,7 @@ # Success string used for RPC instructions. 
VLLM_RPC_SUCCESS_STR = "SUCCESS" +VLLM_RPC_FAILED_STR = "FAILED" @dataclass @@ -21,6 +22,13 @@ class RPCGenerateRequest: prompt_adapter_request: Optional[PromptAdapterRequest] = None +@dataclass +class RPCGenerateError: + request_id: str + is_errored: bool + exception: BaseException + + @dataclass class RPCAbortRequest: request_id: str diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index 9f19c9d22095..fd1490f8aa5d 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -12,12 +12,10 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -# yapf: disable from vllm.engine.multiprocessing import (RPC_REQUEST_TYPE, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) -# yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger @@ -29,6 +27,7 @@ logger = init_logger(__name__) +CHECK_HEALTH_INTERVAL_S = 10. class MPClientClosedError(Exception): """Exception class raised when the client is used post-close. @@ -46,12 +45,13 @@ class MPEngineClient: def __init__(self, ipc_path: str): self.context = zmq.asyncio.Context() self._errored = False + self._errored_with: Optional[BaseException] = None - # Send RPCGenerateRequest to the MPLLMEngine. + # Send RPCGenerateRequest to the MQLLMEngine. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) self.input_socket.connect(f"{ipc_path}_input_socket") - # Receive streams of RequestOutput from the MPLLMEngine. + # Receive streams of RequestOutput from the MQLLMEngine. self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}_output_socket") @@ -75,8 +75,18 @@ def get_data_socket(self) -> Iterator[Socket]: finally: socket.close(linger=0) + async def run_check_health_loop(self): + # Check health periodically. + while True: + await asyncio.sleep(CHECK_HEALTH_INTERVAL_S) + try: + await self._check_health_rpc(self.health_socket) + except Exception as e: + self._errored = True + self._errored_with = e + async def run_output_handler(self): - # Stream lists of RequestOutput from MPLLMEngine. + # Stream lists of RequestOutput from MQLLMEngine. while True: message: Frame = await self.output_socket.recv(copy=False) request_outputs: List[RequestOutput] = pickle.loads(message.buffer) @@ -120,7 +130,7 @@ async def setup(self): enable_lora=bool(await self._get_lora_config_rpc(socket)), ) - # Notify MPLLMEngine client is ready to start sending requests. + # Notify MQLLMEngine client is ready to start sending requests. await self._notify_ready(socket) def close(self): @@ -145,7 +155,7 @@ async def _send_get_data_rpc_request(self, request: RPCStartupRequest, # Make sure the server responds if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " + raise TimeoutError("RPCServer didn't reply within " f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") # Await the data from the Server. 
@@ -168,25 +178,29 @@ async def _send_get_data_rpc_request(self, request: RPCStartupRequest, return data - async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, - socket: Socket): + async def _send_one_way_rpc_request( + self, request: RPC_REQUEST_TYPE, + socket: Socket, + await_ack: bool = False, + error_message: str = "RPCRequest Failed."): """Send one-way RPC request to trigger an action.""" await socket.send_multipart((cloudpickle.dumps(request), )) - # TODO: is there a way to ack this if we are using the input_socket? - # I don't think so, b/c we are using PUSH/PULL w/out identities so no - # way to preserve order. + if await_ack: + await self._ack_one_way_rpc_request( + expected_str=VLLM_RPC_SUCCESS_STR, + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + error_message=error_message, + socket=socket) + async def _ack_one_way_rpc_request( - self, - timeout: int, - expected_str: str, - error_message: str, - socket: Socket, - ): + self, timeout: int, expected_str: str, error_message: str, socket: Socket): + "Await acknoledgement that a request succeeded." + if await socket.poll(timeout=timeout) == 0: - raise TimeoutError(f"MPLLMEngine didn't reply within {timeout}ms") + raise TimeoutError(f"MQLLMEngine didn't reply within {timeout}ms") frame = await socket.recv(copy=False) response = pickle.loads(frame.buffer) @@ -212,22 +226,27 @@ async def is_tracing_enabled(self) -> bool: async def _wait_for_server_rpc(self, socket: Socket): """Wait for the RPCServer to start up.""" - # Readiness probe. - request = RPCStartupRequest.IS_SERVER_READY - await socket.send_multipart((cloudpickle.dumps(request), )) + self._send_one_way_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + socket=socket, + await_ack=True, + error_message="Unable to start RPC Server") - # Raises TimeoutError if not ack, causing a retry. - await self._ack_one_way_rpc_request( - expected_str=VLLM_RPC_SUCCESS_STR, - timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, - error_message="Unable to start RPC Server", - socket=socket) + async def _check_health_rpc(self, socket: Socket): + """Get current health status from the RPCServer""" + + self._send_one_way_rpc_request( + request=RPCUtilityRequest.CHECK_HEALTH, + socket=socket, + await_ack=True, + error_message="Check health failed.") async def _notify_ready(self, socket: Socket): """Get the RPCServer that the RPCClient is ready""" await self._send_one_way_rpc_request( - request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) + request=RPCStartupRequest.CLIENT_IS_READY, + socket=socket) async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -349,6 +368,8 @@ async def generate( while not finished: request_output = await queue.get() + # todo: convert message to include status of whether the engine is errored + # such that we dont need another RPC call if isinstance(request_output, BaseException): finished = True # On exception, check if the server is still healthy @@ -370,23 +391,6 @@ async def generate( if not finished and not self._errored: await self.abort(request_id) - async def check_health(self) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) - - # Await acknowledgement from MPLLMEngine. - # Note: these requests are not necessarily serial. - # I.e. if two clients A, B send CHECK_HEALTH, the - # response to A could actually be the call send by B. - # TODO: is this bad? 
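The await_ack path added here amounts to: send the request, wait up to the timeout for a well-known success string, and surface anything else as an error. A rough stand-in using asyncio queues instead of zmq sockets, so it runs on its own; the 1 s timeout and the queue transport are assumptions.

import asyncio

VLLM_RPC_SUCCESS_STR = "SUCCESS"


async def send_one_way(outbox: asyncio.Queue, inbox: asyncio.Queue,
                       request: str, await_ack: bool = False,
                       timeout_s: float = 1.0) -> None:
    """Fire-and-forget by default; optionally wait for a SUCCESS ack."""
    await outbox.put(request)
    if not await_ack:
        return
    try:
        response = await asyncio.wait_for(inbox.get(), timeout_s)
    except asyncio.TimeoutError:
        raise TimeoutError(f"Engine didn't reply within {timeout_s}s") from None
    if isinstance(response, Exception):
        raise response  # server shipped back the failure itself
    if response != VLLM_RPC_SUCCESS_STR:
        raise ValueError(f"Unexpected ack: {response!r}")


async def main():
    outbox: asyncio.Queue = asyncio.Queue()
    inbox: asyncio.Queue = asyncio.Queue()
    inbox.put_nowait(VLLM_RPC_SUCCESS_STR)  # pretend the server already acked
    await send_one_way(outbox, inbox, "IS_SERVER_READY", await_ack=True)
    print("acked; server will see:", outbox.get_nowait())


asyncio.run(main())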
- await self._ack_one_way_rpc_request( - timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, - expected_str=VLLM_RPC_SUCCESS_STR, - error_message="Check health timeout.", - socket=self.health_socket) - async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: raise NotImplementedError( diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index c0bc8d0d5531..e16785a215a0 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -6,12 +6,14 @@ import ray import zmq -from vllm import AsyncEngineArgs, AsyncLLMEngine, LLMEngine +from vllm import AsyncEngineArgs, LLMEngine +from vllm.executor.executor_base import get_executor_cls from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCGenerateRequest, RPCStartupRequest, - RPCUtilityRequest) +from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, VLLM_RPC_FAILED_STR, + RPCStartupRequest, RPCUtilityRequest, + RPCGenerateRequest, RPCGenerateError, + RPCAbortRequest) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -24,9 +26,18 @@ POLLING_TIMEOUT_MS = 10000 HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) +UNHEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_FAILED_STR), ) +REQUEST_OUTPUTS_T = Union[List[RequestOutput], Tuple[str, RPCGenerateError]] +class MQEngineDeadError(RuntimeError): + pass -class MPLLMEngine: +ENGINE_DEAD_ERROR = MQEngineDeadError( + "Engine loop is not running. Inspect the output to find " + "the stacktrace of the error that caused the engine loop " + "to stop (MQEngineDeadError).") + +class MQLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. This class is used to wrap the :class:`LLMEngine` class to enable use @@ -39,37 +50,34 @@ class MPLLMEngine: The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal :class:`LLMEngine.step()` and sends the RequestOutputs back over - the output_socket. + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. Args: - worker_use_ray: Whether to use Ray for model workers. Required for - distributed execution. Should be the same as - `parallel_config.worker_use_ray`. - engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the - async frontend will be executed in a separate process as the - model workers. - async_engine_args: AsyncLLMEngine args + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. 
""" - _engine_class: Type[LLMEngine] = LLMEngine - def __init__(self, - worker_use_ray: bool, - engine_use_ray: bool, *args, ipc_path: str, use_async_sockets: bool, log_requests: bool = True, **kwargs) -> None: - - if engine_use_ray: - raise NotImplementedError("Not yet supported.") - - self.worker_use_ray = worker_use_ray - self.engine_use_ray = engine_use_ray + self.engine = LLMEngine(*args, **kwargs) self.log_requests = log_requests - self.engine = self._init_engine(*args, **kwargs) + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback self.ctx = zmq.Context() # type: ignore[attr-defined] @@ -88,27 +96,26 @@ def __init__(self, # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}_data_socket" - # Indicates if we do socket read/write async with - # the GPU forward pass - self.use_async_sockets = use_async_sockets + # Error state. + self._errored = False @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): - """Creates an MPLLMEngine from the engine arguments.""" + """Creates an MQLLMEngine from the engine arguments.""" engine_config = engine_args.create_engine_config() if engine_args.engine_use_ray: - from vllm.executor import ray_utils - ray_utils.assert_ray_available() + raise NotImplementedError( + "--engine-use-ray is not supported for MQLLMEngine. " + "Launch with --disable-frontend-multiprocessing if you " + "need to deploy with this flag (not recommended).") - # TODO: better abstraction? - executor_class = AsyncLLMEngine._get_executor_cls(engine_config) + executor_class = get_executor_cls(engine_config) return cls( executor_class.uses_ray, - engine_args.engine_use_ray, **engine_config.to_dict(), executor_class=executor_class, log_requests=not engine_args.disable_log_requests, @@ -125,27 +132,6 @@ def cleanup(self): self.ctx.destroy(linger=0) del self.engine - def _init_engine(self, *args, - **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: - """Initialize the LLMEngine""" - - if not self.engine_use_ray: - engine_class = self._engine_class - elif self.worker_use_ray: - engine_class = ray.remote(num_cpus=0)(self._engine_class).remote - else: - raise NotImplementedError("Not supported yet!") - return engine_class(*args, **kwargs) - - def run_background_loop(self): - """Entrypoint that kicks off the background processing loop.""" - - # Allow RPCClient to query data in startup phase. - self.run_startup_loop() - - # Kick off core processing loop. - self.run_engine_loop() - @contextmanager def make_data_socket( self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] @@ -195,65 +181,74 @@ def run_startup_loop(self) -> None: socket.send_multipart((identity, pickle.dumps(e)), copy=False) - def run_engine_loop(self) -> None: - if self.use_async_sockets: - self.engine.process_request_outputs_callback = \ - self.stream_outputs_and_get_inputs - + def run_engine_loop(self): + """Entrypoint for core busy loop""" while True: - # Block until there is a new request. - if not self.engine.has_unfinished_requests(): - self.wait_for_new_input() + # Poll until there is work to do. + self.poll_for_work() - # Handle any new input from the input socket. - self.maybe_handle_new_input() + # Handle any new data. + self.handle_new_input() # Engine step. - try: - request_outputs = self.engine.step() - except Exception as e: - # Fail all requests with this exception. 
- request_outputs = (None, e) - - if not self.use_async_sockets: - # Stream results to output socket. - self.send_outputs(request_outputs) - - def wait_for_new_input(self): - while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for new request.") - - def send_outputs(self, request_outputs: Union[List[RequestOutput], - Tuple[Optional[str], - BaseException]]): - output_bytes = pickle.dumps(request_outputs) - self.output_socket.send_multipart((output_bytes, ), copy=False) + request_outputs = self.engine_step() - def stream_outputs_and_get_inputs(self, - request_outputs: List[RequestOutput]): - self.send_outputs(request_outputs) - self.maybe_handle_new_input() + # Stream results if neeeded + if (not self.use_async_sockets or + isinstance(request_outputs, RPCGenerateError)): + self._send_request_outputs(request_outputs) - def ack_check_health(self): - self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) - def maybe_handle_new_input(self): - """Handle new input with non-blocking IO""" - while self.input_socket.poll(timeout=0) != 0: - message = self.input_socket.recv(copy=False) - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCGenerateRequest): - self._handle_generate_request(request) - elif isinstance(request, RPCAbortRequest): - self._handle_abort_request(request) - elif isinstance(request, RPCUtilityRequest): - self._handle_utility_request(request) - else: - raise ValueError(f"Unknown RPCRequest: {request}") + def poll_for_work(self): + """Poll the socket until there is work to do.""" + if not self.engine.has_unfinished_requests(): + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for new requests.") + + + def engine_step(self) -> REQUEST_OUTPUTS_T: + """Engine step wrapper with error handling.""" + try: + request_outputs = self.engine.step() + except Exception as e: + self._errored = True + request_outputs = RPCGenerateError(request_id=None, + is_errored=self._errored, + exception=e) + finally: + return request_outputs + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + message = self.input_socket.recv(copy=False) + request = cloudpickle.loads(message.buffer) + + if isinstance(request, RPCGenerateRequest): + # Exceptions in RPCGenerateRequest will be caught + # by the handler, meaning any recoverable exceptions + # to only impact that request (and not crash the server) + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUtilityRequest): + self._handle_utility_request(request) + else: + raise ValueError(f"Unknown RPCRequest: {request}") + + except Exception as e: + self._errored = True + logger.exception(repr(e)) + self._send_unhealthy() def _handle_generate_request(self, request: RPCGenerateRequest): request_id = request.request_id + + if self._errored: + e = RPCGenerateError(request_id, self._errored, ENGINE_DEAD_ERROR) + self._send_request_outputs(e) + try: self.engine.add_request( request_id=request_id, @@ -261,27 +256,63 @@ def _handle_generate_request(self, request: RPCGenerateRequest): params=request.sampling_params, lora_request=request.lora_request, trace_headers=request.trace_headers, - prompt_adapter_request=request.prompt_adapter_request, - ) - except Exception as e: + prompt_adapter_request=request.prompt_adapter_request) + + if self.log_requests: + logger.info("Added request %s.", 
request.request_id) + + except Exception as err: self.engine.abort_request(request_id) - self.send_outputs((request_id, e)) + + # We do not set self._errored = True here, + # since the error is due to an issue adding this + # request to the engine, rather than an issue with + # the engine itself. + e = RPCGenerateError(request_id, self._errored, err) + self._send_request_outputs(e) + def _handle_abort_request(self, request: RPCAbortRequest): self.engine.abort_request(request.request_id) - def _handle_utility_request(self, request: RPCUtilityRequest): + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + + def _handle_utility_request(self, request: RPCUtilityRequest): if request == RPCUtilityRequest.DO_LOG_STATS: self.engine.do_log_stats() elif request == RPCUtilityRequest.CHECK_HEALTH: self.engine.check_health() - self.ack_check_health() + self._send_healthy() + + + def _send_request_outputs(self, request_outputs: REQUEST_OUTPUTS_T): + """Send List of RequestOutput to RPCClient.""" + + output_bytes = pickle.dumps(request_outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self): + """Send UNHEALTHY message to RPCClient.""" + self.health_socket.send_multipart(UNHEALTHY_RESPONSE, copy=False) + + def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_request_outputs(request_outputs) + self.handle_new_input() def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): - engine = MPLLMEngine.from_engine_args(engine_args=engine_args, + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, usage_context=usage_context, ipc_path=ipc_path) - engine.run_background_loop() + engine.run_startup_loop() + engine.run_engine_loop() diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c96cb0f2c298..bf8c82dc6804 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Set, Tuple +from typing import List, Optional, Set, Tuple, Type -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, +from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, + LoadConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) +from vllm.executor.ray_utils import initialize_ray_cluster from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -148,3 +150,66 @@ async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" self.check_health() + + +def get_executor_cls( + engine_config: EngineConfig) -> Type["ExecutorAsyncBase"]: + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorAsyncBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorAsyncBase. 
Got {distributed_executor_backend}.") + if distributed_executor_backend.uses_ray: # type: ignore + initialize_ray_cluster(engine_config.parallel_config) + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutorAsync + executor_class = NeuronExecutorAsync + elif engine_config.device_config.device_type == "tpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync + executor_class = RayTPUExecutorAsync + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutorAsync + executor_class = TPUExecutorAsync + elif engine_config.device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutorAsync + executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "openvino": + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with " + "the OpenVINO backend.") + from vllm.executor.openvino_executor import OpenVINOExecutorAsync + executor_class = OpenVINOExecutorAsync + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend is None: + from vllm.executor.xpu_executor import XPUExecutorAsync + executor_class = XPUExecutorAsync + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync + executor_class = RayXPUExecutorAsync + elif distributed_executor_backend == "mp": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.multiproc_xpu_executor import ( + MultiprocessingXPUExecutorAsync) + executor_class = MultiprocessingXPUExecutorAsync + else: + raise RuntimeError( + "Not supported distributed execution model on XPU device.") + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync + executor_class = RayGPUExecutorAsync + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutorAsync) + executor_class = MultiprocessingGPUExecutorAsync + else: + from vllm.executor.gpu_executor import GPUExecutorAsync + executor_class = GPUExecutorAsync + return executor_class From b7c1fcc8e8d50c84817aec16db7c448c524fa622 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 7 Sep 2024 22:07:11 +0000 Subject: [PATCH 037/116] error handling --- vllm/engine/multiprocessing/__init__.py | 28 ++- vllm/engine/multiprocessing/mp_client.py | 210 ++++++++++++------- vllm/engine/multiprocessing/mp_llm_engine.py | 37 ++-- vllm/entrypoints/openai/api_server.py | 4 +- 4 files changed, 175 insertions(+), 104 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 26c4d42d4232..2372e3bd6c22 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,16 +1,26 @@ from dataclasses import dataclass from enum import Enum -from typing import Mapping, Optional, Union +from typing import List, Mapping, Optional, Union from vllm.inputs import PromptInputs +from vllm.outputs import RequestOutput from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -# Success 
string used for RPC instructions. + VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_FAILED_STR = "FAILED" +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + @dataclass class RPCGenerateRequest: @@ -24,8 +34,8 @@ class RPCGenerateRequest: @dataclass class RPCGenerateError: - request_id: str - is_errored: bool + request_id: Optional[str] + is_engine_errored: bool exception: BaseException @@ -50,5 +60,13 @@ class RPCStartupRequest(Enum): CLIENT_IS_READY = 8 -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, +RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCUtilityRequest, RPCStartupRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput, RPCGenerateError]] + +ENGINE_DEAD_ERROR = MQEngineDeadError( + "Engine loop is not running. Inspect the output to find " + "the stacktrace of the error that caused the engine loop " + "to stop (MQEngineDeadError).") + diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index fd1490f8aa5d..a3884f20e219 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -12,10 +12,13 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.engine.multiprocessing import (RPC_REQUEST_TYPE, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCGenerateRequest, RPCStartupRequest, - RPCUtilityRequest) +from vllm.engine.multiprocessing import (IPC_INPUT_EXT, IPC_OUTPUT_EXT, + IPC_HEALTH_EXT, IPC_DATA_EXT, + RPC_REQUEST_T, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, + ENGINE_DEAD_ERROR, RPCAbortRequest, + RPCGenerateRequest, RPCGenerateError, + RPCStartupRequest, RPCUtilityRequest) from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger @@ -27,7 +30,6 @@ logger = init_logger(__name__) -CHECK_HEALTH_INTERVAL_S = 10. class MPClientClosedError(Exception): """Exception class raised when the client is used post-close. @@ -40,31 +42,57 @@ class MPClientClosedError(Exception): """ -class MPEngineClient: +class MQLLMEngineClient: + """A client wrapper for MQLLMEngine that conforms to the + AsyncEngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ def __init__(self, ipc_path: str): self.context = zmq.asyncio.Context() self._errored = False - self._errored_with: Optional[BaseException] = None # Send RPCGenerateRequest to the MQLLMEngine. 
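The IPC_*_EXT suffixes define one base ipc path with derived endpoints, where the engine binds each endpoint and the client connects the mirror-image socket type. A sketch of that wiring for the input, output, and health sockets follows; the base path is made up and the data socket is omitted for brevity.

import zmq

IPC_INPUT_EXT = "_input_socket"    # client PUSH  -> engine PULL
IPC_OUTPUT_EXT = "_output_socket"  # engine PUSH  -> client PULL
IPC_HEALTH_EXT = "_health_socket"  # engine PUSH  -> client PULL

base = "ipc:///tmp/demo_engine"    # assumed base path
ctx = zmq.Context.instance()

# Engine side binds each endpoint...
engine_input = ctx.socket(zmq.PULL)
engine_input.bind(base + IPC_INPUT_EXT)
engine_output = ctx.socket(zmq.PUSH)
engine_output.bind(base + IPC_OUTPUT_EXT)
engine_health = ctx.socket(zmq.PUSH)
engine_health.bind(base + IPC_HEALTH_EXT)

# ...and the client connects the mirror-image socket types.
client_input = ctx.socket(zmq.PUSH)
client_input.connect(base + IPC_INPUT_EXT)
client_output = ctx.socket(zmq.PULL)
client_output.connect(base + IPC_OUTPUT_EXT)
client_health = ctx.socket(zmq.PULL)
client_health.connect(base + IPC_HEALTH_EXT)

client_input.send_pyobj({"request_id": "demo-0"})
print(engine_input.recv_pyobj())   # {'request_id': 'demo-0'}

ctx.destroy(linger=0)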
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.input_socket.connect(f"{ipc_path}_input_socket") + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") # Receive streams of RequestOutput from the MQLLMEngine. self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.output_socket.connect(f"{ipc_path}_output_socket") + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") # IPC path for ack of check_health requests. self.health_socket: Socket = self.context.socket(zmq.constants.PULL) - self.health_socket.connect(f"{ipc_path}_health_socket") + self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}_data_socket" + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" # Stream for each individual request. self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_handler = asyncio.create_task(self.run_output_handler()) + self.output_loop = asyncio.create_task( + self.run_output_handler_loop()) + + # Loop to check health of the LLMEngine periodically. + self.health_loop = asyncio.create_task( + self.run_check_health_loop(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS)) @contextmanager def get_data_socket(self) -> Iterator[Socket]: @@ -75,37 +103,63 @@ def get_data_socket(self) -> Iterator[Socket]: finally: socket.close(linger=0) - async def run_check_health_loop(self): - # Check health periodically. - while True: - await asyncio.sleep(CHECK_HEALTH_INTERVAL_S) - try: - await self._check_health_rpc(self.health_socket) - except Exception as e: - self._errored = True - self._errored_with = e - - async def run_output_handler(self): - # Stream lists of RequestOutput from MQLLMEngine. - while True: - message: Frame = await self.output_socket.recv(copy=False) - request_outputs: List[RequestOutput] = pickle.loads(message.buffer) - - for output in request_outputs: - if isinstance(output, tuple): - # Exception case - request_id, output = output + async def run_check_health_loop(self, timeout: int): + try: + while True: + if await self.health_socket.poll(timeout=timeout) == 0: + # Wakeup every N seconds and do a health probe. + await self._check_health_rpc(self.health_socket) else: - request_id = output.request_id - - if request_id is not None: - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(output) + # Server sent a health status message unprompted. + self._check_success(error_message="Health check failed", + socket=self.health_socket) + except asyncio.CancelledError: + logger.info("Shutting down MQLLMEngineClient check health loop.") + except Exception as e: + logger.exception(repr(e)) + self._errored = True + + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to request Queues""" + + try: + while True: + message: Frame = await self.output_socket.recv(copy=False) + request_outputs: REQUEST_OUTPUTS_T = pickle.loads(message.buffer) + + if isinstance(request_outputs, RPCGenerateError): + error: RPCGenerateError = request_outputs + + if error.is_engine_errored: + self._errored = True + + if error.request_id is None: + # Apply exception to all active requests. + + # TODO: this sends the exceptions to the PENDING requests too. + # Do we want this? Shouldn't we be sending EngineDeadError to PENDING? 
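The error branch above has two delivery modes: an error carrying a request_id goes to that request's queue, while an engine-level error (request_id None) is fanned out to every open queue. A small sketch of just that routing decision, with plain exceptions standing in for the RPC error type.

import asyncio
from typing import Dict, Optional


def route_error(output_queues: Dict[str, asyncio.Queue],
                request_id: Optional[str],
                exception: BaseException) -> None:
    """Deliver an exception to one request stream, or to all of them."""
    if request_id is None:
        # Engine-level failure: every in-flight request sees the exception.
        for queue in tuple(output_queues.values()):
            queue.put_nowait(exception)
    else:
        # Request-level failure: only that request's consumer is woken up.
        queue = output_queues.get(request_id)
        if queue is not None:
            queue.put_nowait(exception)


queues = {"req-0": asyncio.Queue(), "req-1": asyncio.Queue()}
route_error(queues, "req-0", ValueError("bad sampling params"))
route_error(queues, None, RuntimeError("engine loop died"))
assert queues["req-0"].qsize() == 2 and queues["req-1"].qsize() == 1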
+ for queue in tuple(self.output_queues.values()): + queue.put_nowait(error.exception) + else: + queue = self.output_queues.get(error.request_id) + if queue is not None: + queue.put_nowait(error.exception) else: - # request_id None means apply to all active requests. - for queue in tuple(self.output_queues.values()): - queue.put_nowait(output) + # TODO: what should we do if the RPCServer sends back a raw exception? + assert not isinstance(request_outputs, BaseException), ( + "Got unhandled raw unhandled Exception from RPCServer. " + "This should never happen.") + + # Put each output into the appropriate steam. + for request_output in request_outputs: + queue = self.output_queues.get(request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.info("Shutting down MQLLMEngineClient output handler.") + async def setup(self): """Setup the client before it starts sending server requests.""" @@ -142,7 +196,9 @@ def close(self): self.health_socket.close() self.context.destroy(linger=0) - # TODO: cancel the handler task. + # Cancel background tasks. + self.health_loop.cancel() + self.output_loop.cancel() async def _send_get_data_rpc_request(self, request: RPCStartupRequest, expected_type: Any, @@ -153,7 +209,7 @@ async def _send_get_data_rpc_request(self, request: RPCStartupRequest, # Ping RPCServer with a request. await socket.send_multipart((cloudpickle.dumps(request), ), copy=False) - # Make sure the server responds + # Make sure the server responds in time. if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("RPCServer didn't reply within " f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") @@ -179,7 +235,8 @@ async def _send_get_data_rpc_request(self, request: RPCStartupRequest, return data async def _send_one_way_rpc_request( - self, request: RPC_REQUEST_TYPE, + self, + request: RPC_REQUEST_T, socket: Socket, await_ack: bool = False, error_message: str = "RPCRequest Failed."): @@ -188,24 +245,24 @@ async def _send_one_way_rpc_request( await socket.send_multipart((cloudpickle.dumps(request), )) if await_ack: - await self._ack_one_way_rpc_request( - expected_str=VLLM_RPC_SUCCESS_STR, - timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + await self._await_ack( error_message=error_message, socket=socket) - - async def _ack_one_way_rpc_request( - self, timeout: int, expected_str: str, error_message: str, socket: Socket): + async def _await_ack(self, error_message: str, socket: Socket): "Await acknoledgement that a request succeeded." - if await socket.poll(timeout=timeout) == 0: - raise TimeoutError(f"MQLLMEngine didn't reply within {timeout}ms") + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS}ms") + + await self._check_success(error_message, socket) + async def _check_success(self, error_message: str, socket: Socket): frame = await socket.recv(copy=False) response = pickle.loads(frame.buffer) - if not isinstance(response, str) or response != expected_str: + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: if isinstance(response, Exception): logger.error(error_message) raise response @@ -306,19 +363,10 @@ async def _is_tracing_enabled_rpc(self, socket: Socket) -> bool: async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" - # Suppress timeouts and MPClientClosedError. 
- # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(MPClientClosedError, TimeoutError): + with suppress(MPClientClosedError): await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), socket=self.input_socket) + request=RPCAbortRequest(request_id), + socket=self.input_socket) async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" @@ -327,6 +375,16 @@ async def do_log_stats(self): request=RPCUtilityRequest.DO_LOG_STATS, socket=self.input_socket) + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored if + the engine is unhealth. So check_health just raises + an ENGINE_DEAD_ERROR if we find self._errored + """ + if self._errored: + raise ENGINE_DEAD_ERROR + @property def is_running(self) -> bool: return not self._errored @@ -353,9 +411,8 @@ async def generate( queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue - finished = False - try: + try: # Send RPCGenerateRequest to the RPCServer. await self.input_socket.send_multipart((cloudpickle.dumps( RPCGenerateRequest( @@ -366,27 +423,24 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) + # Stream from the output queue. + finished = False while not finished: request_output = await queue.get() - # todo: convert message to include status of whether the engine is errored - # such that we dont need another RPC call + if isinstance(request_output, BaseException): - finished = True - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health() - except Exception as e: - self._errored = True - logger.exception(repr(e)) raise request_output finished = request_output.finished yield request_output finally: + # TODO: check if aborted requests are getting here. + # TODO: check if requests + + # Remove output stream. self.output_queues.pop(request_id) + # Request was canceled by the client. 
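generate() is essentially an async generator backed by a per-request queue: a background task fills the queue, the generator drains it until an output is marked finished, and any exception placed on the queue is re-raised to the caller. A compressed, vLLM-free sketch of that shape; the Output dataclass and the producer task are stand-ins.

import asyncio
from dataclasses import dataclass


@dataclass
class Output:               # stand-in for RequestOutput
    text: str
    finished: bool = False


async def stream_request(queue: asyncio.Queue):
    """Yield outputs until one is marked finished; re-raise queued errors."""
    while True:
        item = await queue.get()
        if isinstance(item, BaseException):
            raise item
        yield item
        if item.finished:
            return


async def main():
    queue: asyncio.Queue = asyncio.Queue()

    async def producer():   # plays the role of the output handler loop
        for i in range(3):
            await queue.put(Output(text=f"token-{i}", finished=(i == 2)))

    producer_task = asyncio.create_task(producer())
    async for out in stream_request(queue):
        print(out.text, out.finished)
    await producer_task


asyncio.run(main())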
if not finished and not self._errored: await self.abort(request_id) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index e16785a215a0..5fe8002e35e7 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -10,7 +10,10 @@ from vllm.executor.executor_base import get_executor_cls from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, VLLM_RPC_FAILED_STR, +from vllm.engine.multiprocessing import (IPC_INPUT_EXT, IPC_OUTPUT_EXT, + IPC_HEALTH_EXT, IPC_DATA_EXT, + VLLM_RPC_SUCCESS_STR, VLLM_RPC_FAILED_STR, + ENGINE_DEAD_ERROR, REQUEST_OUTPUTS_T, RPCStartupRequest, RPCUtilityRequest, RPCGenerateRequest, RPCGenerateError, RPCAbortRequest) @@ -24,18 +27,8 @@ logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 - HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) UNHEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_FAILED_STR), ) -REQUEST_OUTPUTS_T = Union[List[RequestOutput], Tuple[str, RPCGenerateError]] - -class MQEngineDeadError(RuntimeError): - pass - -ENGINE_DEAD_ERROR = MQEngineDeadError( - "Engine loop is not running. Inspect the output to find " - "the stacktrace of the error that caused the engine loop " - "to stop (MQEngineDeadError).") class MQLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -83,18 +76,18 @@ def __init__(self, # Receive input from the client. self.input_socket = self.ctx.socket(zmq.constants.PULL) - self.input_socket.bind(f"{ipc_path}_input_socket") + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") # Send output stream back to client. self.output_socket = self.ctx.socket(zmq.constants.PUSH) - self.output_socket.bind(f"{ipc_path}_output_socket") + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") # Send health status back to client. self.health_socket = self.ctx.socket(zmq.constants.PUSH) - self.health_socket.bind(f"{ipc_path}_health_socket") + self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}_data_socket" + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" # Error state. self._errored = False @@ -222,7 +215,8 @@ def handle_new_input(self): """Handle new input from the socket""" try: while self.input_socket.poll(timeout=0) != 0: - message = self.input_socket.recv(copy=False) + # TODO: do we need error handling around the pickling? 
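handle_new_input drains the input socket without blocking: poll(timeout=0) only reports messages that are already waiting, so the engine returns to stepping as soon as the backlog is empty. A standalone sketch of the drain-and-dispatch pattern; the inproc transport and the toy request types are assumptions.

import pickle
from dataclasses import dataclass

import zmq


@dataclass
class GenerateRequest:      # stand-ins for the RPC request dataclasses
    request_id: str


@dataclass
class AbortRequest:
    request_id: str


ctx = zmq.Context.instance()
engine_input = ctx.socket(zmq.PULL)
engine_input.bind("inproc://engine_input")
client_input = ctx.socket(zmq.PUSH)
client_input.connect("inproc://engine_input")

# The client has already queued a couple of requests.
client_input.send(pickle.dumps(GenerateRequest("req-0")))
client_input.send(pickle.dumps(AbortRequest("req-0")))

# The engine drains everything that is waiting, then goes back to stepping.
while engine_input.poll(timeout=0) != 0:
    request = pickle.loads(engine_input.recv())
    if isinstance(request, GenerateRequest):
        print("add to engine:", request.request_id)
    elif isinstance(request, AbortRequest):
        print("abort:", request.request_id)
    else:
        raise ValueError(f"Unknown request: {request!r}")

ctx.destroy(linger=0)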
+ message = self.input_socket.recv(copy=False) request = cloudpickle.loads(message.buffer) if isinstance(request, RPCGenerateRequest): @@ -235,7 +229,8 @@ def handle_new_input(self): elif isinstance(request, RPCUtilityRequest): self._handle_utility_request(request) else: - raise ValueError(f"Unknown RPCRequest: {request}") + raise ValueError( + "Unknown RPCRequest Type: {request}") except Exception as e: self._errored = True @@ -243,6 +238,7 @@ def handle_new_input(self): self._send_unhealthy() def _handle_generate_request(self, request: RPCGenerateRequest): + """Handle RPCGenerateRequest by adding it to the LLMEngine.""" request_id = request.request_id if self._errored: @@ -274,7 +270,6 @@ def _handle_generate_request(self, request: RPCGenerateRequest): def _handle_abort_request(self, request: RPCAbortRequest): self.engine.abort_request(request.request_id) - if self.log_requests: logger.info("Aborted request %s.", request.request_id) @@ -290,6 +285,8 @@ def _handle_utility_request(self, request: RPCUtilityRequest): def _send_request_outputs(self, request_outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" + # TODO: do we need error handling around the pickling? + output_bytes = pickle.dumps(request_outputs) self.output_socket.send_multipart((output_bytes, ), copy=False) @@ -297,11 +294,13 @@ def _send_request_outputs(self, request_outputs: REQUEST_OUTPUTS_T): def _send_healthy(self): """Send HEALTHY message to RPCClient.""" self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) - + + def _send_unhealthy(self): """Send UNHEALTHY message to RPCClient.""" self.health_socket.send_multipart(UNHEALTHY_RESPONSE, copy=False) + def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): """Callback used by engine to make socket handling async with GPU.""" self._send_request_outputs(request_outputs) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f97c84020e84..79ab3d586a68 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -21,7 +21,7 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.multiprocessing.mp_client import MPEngineClient +from vllm.engine.multiprocessing.mp_client import MQLLMEngineClient from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.launcher import serve_http @@ -167,7 +167,7 @@ async def build_async_engine_client_from_engine_args( # Build RPCClient, which conforms to AsyncEngineClient Protocol. # NOTE: Actually, this is not true yet. We still need to support # embedding models via RPC (see TODO above) - mp_engine_client = MPEngineClient(ipc_path) + mp_engine_client = MQLLMEngineClient(ipc_path) # Start RPCServer in separate process (holds the LLMEngine). 
# the current process might have CUDA context, From a661b7659c72339cb31fa13d3b1185d302742570 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 7 Sep 2024 22:21:35 +0000 Subject: [PATCH 038/116] added --- vllm/engine/multiprocessing/__init__.py | 2 +- .../{mp_client.py => client.py} | 52 ++++++++++++------- .../{mp_llm_engine.py => engine.py} | 0 vllm/entrypoints/openai/api_server.py | 2 - 4 files changed, 35 insertions(+), 21 deletions(-) rename vllm/engine/multiprocessing/{mp_client.py => client.py} (90%) rename vllm/engine/multiprocessing/{mp_llm_engine.py => engine.py} (100%) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 2372e3bd6c22..c0cf21d274ea 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -63,7 +63,7 @@ class RPCStartupRequest(Enum): RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCUtilityRequest, RPCStartupRequest] -REQUEST_OUTPUTS_T = Union[List[RequestOutput, RPCGenerateError]] +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCGenerateError] ENGINE_DEAD_ERROR = MQEngineDeadError( "Engine loop is not running. Inspect the output to find " diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/client.py similarity index 90% rename from vllm/engine/multiprocessing/mp_client.py rename to vllm/engine/multiprocessing/client.py index a3884f20e219..5173e53afdaf 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/client.py @@ -113,8 +113,10 @@ async def run_check_health_loop(self, timeout: int): # Server sent a health status message unprompted. self._check_success(error_message="Health check failed", socket=self.health_socket) + except asyncio.CancelledError: logger.info("Shutting down MQLLMEngineClient check health loop.") + except Exception as e: logger.exception(repr(e)) self._errored = True @@ -128,23 +130,33 @@ async def run_output_handler_loop(self): message: Frame = await self.output_socket.recv(copy=False) request_outputs: REQUEST_OUTPUTS_T = pickle.loads(message.buffer) - if isinstance(request_outputs, RPCGenerateError): - error: RPCGenerateError = request_outputs + if isinstance(request_outputs, BaseException): - if error.is_engine_errored: + if isinstance(request_outputs, RPCGenerateError): + error: RPCGenerateError = request_outputs + request_id = error.request_id + if error.is_engine_errored: + self._errored = True + exception = error.exception + else: + # MPLLMEngine should always return an RPCGenerateError + # if the error handling is graceful. If we are here, + # we are in a bad state and should shut down the server. + error: BaseException = request_output + logger.warning( + "Got raw Exception {error} from MPLLMEngine. " + "This should never happen.") self._errored = True + request_id = None + exception = error - if error.request_id is None: - # Apply exception to all active requests. - - # TODO: this sends the exceptions to the PENDING requests too. - # Do we want this? Shouldn't we be sending EngineDeadError to PENDING? + if request_id is None: for queue in tuple(self.output_queues.values()): - queue.put_nowait(error.exception) + queue.put_nowait(exception) else: - queue = self.output_queues.get(error.request_id) + queue = self.output_queues.get(request_id) if queue is not None: - queue.put_nowait(error.exception) + queue.put_nowait(exception) else: # TODO: what should we do if the RPCServer sends back a raw exception? 
assert not isinstance(request_outputs, BaseException), ( @@ -408,12 +420,15 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - queue: asyncio.Queue[Union[RequestOutput, - BaseException]] = asyncio.Queue() + if self._errored: + raise ENGINE_DEAD_ERROR + + # 1) Create output queue for this requests. + queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue try: - # Send RPCGenerateRequest to the RPCServer. + # 2) Send the RPCGenerateRequest to the MQLLMEngine. await self.input_socket.send_multipart((cloudpickle.dumps( RPCGenerateRequest( inputs=inputs, @@ -423,7 +438,9 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) - # Stream from the output queue. + # 3) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. finished = False while not finished: request_output = await queue.get() @@ -435,12 +452,11 @@ async def generate( yield request_output finally: + # TODO: check if excepted requests are getting here. # TODO: check if aborted requests are getting here. - # TODO: check if requests - # Remove output stream. self.output_queues.pop(request_id) - + # Request was canceled by the client. if not finished and not self._errored: await self.abort(request_id) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/engine.py similarity index 100% rename from vllm/engine/multiprocessing/mp_llm_engine.py rename to vllm/engine/multiprocessing/engine.py diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 79ab3d586a68..3e4faa80be59 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -41,8 +41,6 @@ TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion From 5d00f3ac07f6500e925596666f36e418e2a157c2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 7 Sep 2024 23:04:05 +0000 Subject: [PATCH 039/116] formatting in place --- vllm/engine/async_llm_engine.py | 1 - vllm/engine/multiprocessing/__init__.py | 12 +-- vllm/engine/multiprocessing/client.py | 133 ++++++++++++------------ vllm/engine/multiprocessing/engine.py | 76 ++++++-------- vllm/entrypoints/openai/api_server.py | 4 +- vllm/executor/executor_base.py | 12 +-- 6 files changed, 109 insertions(+), 129 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3b6ab0f03c16..78cdfa1dedf7 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -650,7 +650,6 @@ def __init__(self, # Lazy initialized fields self._request_tracker: RequestTracker - @classmethod def from_engine_args( cls, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index c0cf21d274ea..ffcdd2ad2c92 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -3,12 +3,11 @@ from typing import List, Mapping, Optional, Union from vllm.inputs import PromptInputs -from vllm.outputs import RequestOutput from 
vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams - VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_FAILED_STR = "FAILED" @@ -33,7 +32,7 @@ class RPCGenerateRequest: @dataclass -class RPCGenerateError: +class RPCGenerateError(BaseException): request_id: Optional[str] is_engine_errored: bool exception: BaseException @@ -60,13 +59,12 @@ class RPCStartupRequest(Enum): CLIENT_IS_READY = 8 -RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest, RPCStartupRequest] +RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCUtilityRequest, + RPCStartupRequest] -REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCGenerateError] +REQUEST_OUTPUTS_T = Union[List[RequestOutput], BaseException] ENGINE_DEAD_ERROR = MQEngineDeadError( "Engine loop is not running. Inspect the output to find " "the stacktrace of the error that caused the engine loop " "to stop (MQEngineDeadError).") - diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 5173e53afdaf..bc2c27c9331e 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -1,8 +1,8 @@ import asyncio import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, - Optional, Union) +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) import cloudpickle import zmq @@ -12,13 +12,13 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.engine.multiprocessing import (IPC_INPUT_EXT, IPC_OUTPUT_EXT, - IPC_HEALTH_EXT, IPC_DATA_EXT, - RPC_REQUEST_T, REQUEST_OUTPUTS_T, - VLLM_RPC_SUCCESS_STR, - ENGINE_DEAD_ERROR, RPCAbortRequest, - RPCGenerateRequest, RPCGenerateError, - RPCStartupRequest, RPCUtilityRequest) +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, + RPCAbortRequest, RPCGenerateError, + RPCGenerateRequest, RPCStartupRequest, + RPCUtilityRequest) from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger @@ -87,12 +87,10 @@ def __init__(self, ipc_path: str): # Stream for each individual request. self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_loop = asyncio.create_task( - self.run_output_handler_loop()) - + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + # Loop to check health of the LLMEngine periodically. - self.health_loop = asyncio.create_task( - self.run_check_health_loop(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS)) + self.health_loop: Optional[asyncio.Task] = None @contextmanager def get_data_socket(self) -> Iterator[Socket]: @@ -104,15 +102,33 @@ def get_data_socket(self) -> Iterator[Socket]: socket.close(linger=0) async def run_check_health_loop(self, timeout: int): + """Background loop that continually probes the RPCServer for health. + + The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which + the MQLLMEngine server is blocking on. + + The Server replies on the HEALTH_SOCKET (rather than on the + OUTPUT_SOCKET such that the messages are not intermingled with + output streaming). + """ + try: while True: if await self.health_socket.poll(timeout=timeout) == 0: # Wakeup every N seconds and do a health probe. 
- await self._check_health_rpc(self.health_socket) + await self._send_one_way_rpc_request( + RPCUtilityRequest.CHECK_HEALTH, self.input_socket) + + # Wait for ack from the health socket. + await self._await_ack(error_message="Health check failed.", + socket=self.health_socket) else: # Server sent a health status message unprompted. - self._check_success(error_message="Health check failed", - socket=self.health_socket) + await self._check_success( + error_message="Health check failed.", + socket=self.health_socket) + + logger.debug("Health probe complete.") except asyncio.CancelledError: logger.info("Shutting down MQLLMEngineClient check health loop.") @@ -121,28 +137,27 @@ async def run_check_health_loop(self, timeout: int): logger.exception(repr(e)) self._errored = True - async def run_output_handler_loop(self): """Get RequestOutputs from Engine and stream to request Queues""" - + try: while True: message: Frame = await self.output_socket.recv(copy=False) - request_outputs: REQUEST_OUTPUTS_T = pickle.loads(message.buffer) + request_outputs: REQUEST_OUTPUTS_T = pickle.loads( + message.buffer) if isinstance(request_outputs, BaseException): - if isinstance(request_outputs, RPCGenerateError): - error: RPCGenerateError = request_outputs - request_id = error.request_id - if error.is_engine_errored: + generate_error: RPCGenerateError = request_outputs + request_id = generate_error.request_id + if generate_error.is_engine_errored: self._errored = True - exception = error.exception + exception = generate_error.exception else: # MPLLMEngine should always return an RPCGenerateError # if the error handling is graceful. If we are here, # we are in a bad state and should shut down the server. - error: BaseException = request_output + error: BaseException = request_outputs logger.warning( "Got raw Exception {error} from MPLLMEngine. " "This should never happen.") @@ -151,28 +166,23 @@ async def run_output_handler_loop(self): exception = error if request_id is None: - for queue in tuple(self.output_queues.values()): - queue.put_nowait(exception) + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) else: queue = self.output_queues.get(request_id) if queue is not None: queue.put_nowait(exception) else: - # TODO: what should we do if the RPCServer sends back a raw exception? - assert not isinstance(request_outputs, BaseException), ( - "Got unhandled raw unhandled Exception from RPCServer. " - "This should never happen.") - # Put each output into the appropriate steam. for request_output in request_outputs: - queue = self.output_queues.get(request_output.request_id) + queue = self.output_queues.get( + request_output.request_id) if queue is not None: queue.put_nowait(request_output) - + except asyncio.CancelledError: logger.info("Shutting down MQLLMEngineClient output handler.") - async def setup(self): """Setup the client before it starts sending server requests.""" @@ -196,6 +206,11 @@ async def setup(self): enable_lora=bool(await self._get_lora_config_rpc(socket)), ) + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_check_health_loop( + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS)) + # Notify MQLLMEngine client is ready to start sending requests. await self._notify_ready(socket) @@ -209,7 +224,8 @@ def close(self): self.context.destroy(linger=0) # Cancel background tasks. 
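The health loop is a watchdog: wake up periodically, probe, and record the failure in a shared flag so later calls can fail fast instead of hanging on a dead engine. A vLLM-free sketch of that shape; the probe callable and the interval are assumptions.

import asyncio
from typing import Awaitable, Callable, Optional


class Watchdog:
    """Sketch of a background health probe that records the first failure."""

    def __init__(self, probe: Callable[[], Awaitable[None]],
                 interval_s: float = 0.5):
        self._probe = probe            # async callable that raises if unhealthy
        self._interval_s = interval_s
        self.errored_with: Optional[BaseException] = None

    async def run(self) -> None:
        try:
            while True:
                await asyncio.sleep(self._interval_s)
                await self._probe()
        except asyncio.CancelledError:
            pass                       # normal shutdown path
        except Exception as e:
            self.errored_with = e      # remember why the engine is considered dead

    def check_health(self) -> None:
        """Fail fast if the watchdog has already seen a failed probe."""
        if self.errored_with is not None:
            raise self.errored_with


async def main():
    calls = 0

    async def probe() -> None:
        nonlocal calls
        calls += 1
        if calls == 3:
            raise RuntimeError("health probe failed")

    dog = Watchdog(probe, interval_s=0.01)
    await asyncio.create_task(dog.run())   # returns once the probe fails
    try:
        dog.check_health()
    except RuntimeError as e:
        print("engine marked dead:", e)


asyncio.run(main())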
- self.health_loop.cancel() + if self.health_loop is not None: + self.health_loop.cancel() self.output_loop.cancel() async def _send_get_data_rpc_request(self, request: RPCStartupRequest, @@ -246,23 +262,14 @@ async def _send_get_data_rpc_request(self, request: RPCStartupRequest, return data - async def _send_one_way_rpc_request( - self, - request: RPC_REQUEST_T, - socket: Socket, - await_ack: bool = False, - error_message: str = "RPCRequest Failed."): + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_T, + socket: Socket): """Send one-way RPC request to trigger an action.""" await socket.send_multipart((cloudpickle.dumps(request), )) - if await_ack: - await self._await_ack( - error_message=error_message, - socket=socket) - async def _await_ack(self, error_message: str, socket: Socket): - "Await acknoledgement that a request succeeded." + "Await acknowledgement that a request succeeded." if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("MQLLMEngine didn't reply within " @@ -295,27 +302,17 @@ async def is_tracing_enabled(self) -> bool: async def _wait_for_server_rpc(self, socket: Socket): """Wait for the RPCServer to start up.""" - self._send_one_way_rpc_request( - request=RPCStartupRequest.IS_SERVER_READY, - socket=socket, - await_ack=True, - error_message="Unable to start RPC Server") - - async def _check_health_rpc(self, socket: Socket): - """Get current health status from the RPCServer""" + await self._send_one_way_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, socket=socket) - self._send_one_way_rpc_request( - request=RPCUtilityRequest.CHECK_HEALTH, - socket=socket, - await_ack=True, - error_message="Check health failed.") + await self._await_ack(error_message="Unable to start RPC Server", + socket=socket) async def _notify_ready(self, socket: Socket): """Get the RPCServer that the RPCClient is ready""" await self._send_one_way_rpc_request( - request=RPCStartupRequest.CLIENT_IS_READY, - socket=socket) + request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -377,8 +374,7 @@ async def abort(self, request_id: str): with suppress(MPClientClosedError): await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - socket=self.input_socket) + request=RPCAbortRequest(request_id), socket=self.input_socket) async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" @@ -424,7 +420,8 @@ async def generate( raise ENGINE_DEAD_ERROR # 1) Create output queue for this requests. - queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue try: @@ -456,7 +453,7 @@ async def generate( # TODO: check if aborted requests are getting here. self.output_queues.pop(request_id) - + # Request was canceled by the client. 
if not finished and not self._errored: await self.abort(request_id) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 5fe8002e35e7..12511da16b5e 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,24 +1,22 @@ import pickle from contextlib import contextmanager -from typing import Iterator, List, Optional, Tuple, Type, Union +from typing import Iterator, Union import cloudpickle -import ray import zmq from vllm import AsyncEngineArgs, LLMEngine -from vllm.executor.executor_base import get_executor_cls from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.engine.multiprocessing import (IPC_INPUT_EXT, IPC_OUTPUT_EXT, - IPC_HEALTH_EXT, IPC_DATA_EXT, - VLLM_RPC_SUCCESS_STR, VLLM_RPC_FAILED_STR, - ENGINE_DEAD_ERROR, REQUEST_OUTPUTS_T, - RPCStartupRequest, RPCUtilityRequest, - RPCGenerateRequest, RPCGenerateError, - RPCAbortRequest) +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_FAILED_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateError, RPCGenerateRequest, + RPCStartupRequest, RPCUtilityRequest) +from vllm.executor.executor_base import get_executor_cls from vllm.logger import init_logger -from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, @@ -30,6 +28,7 @@ HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) UNHEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_FAILED_STR), ) + class MQLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -59,9 +58,9 @@ class MQLLMEngine: """ def __init__(self, - *args, ipc_path: str, use_async_sockets: bool, + *args, log_requests: bool = True, **kwargs) -> None: self.engine = LLMEngine(*args, **kwargs) @@ -108,14 +107,13 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, executor_class = get_executor_cls(engine_config) return cls( - executor_class.uses_ray, + ipc_path=ipc_path, + use_async_sockets=engine_config.model_config.use_async_output_proc, **engine_config.to_dict(), executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - ipc_path=ipc_path, - use_async_sockets=engine_config.model_config.use_async_output_proc) + usage_context=usage_context) def cleanup(self): """Cleanup zeromq state on shutdown.""" @@ -175,7 +173,8 @@ def run_startup_loop(self) -> None: copy=False) def run_engine_loop(self): - """Entrypoint for core busy loop""" + """Entrypoint for core busy loop of the LLMEngine.""" + while True: # Poll until there is work to do. self.poll_for_work() @@ -186,37 +185,33 @@ def run_engine_loop(self): # Engine step. request_outputs = self.engine_step() - # Stream results if neeeded - if (not self.use_async_sockets or - isinstance(request_outputs, RPCGenerateError)): + # Stream results if needed. 
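Putting the engine side together, the core loop is: block until there is work, drain any queued input, step the engine, and push the results (or an error object) back to the client. A single-process sketch of that control flow with queue.Queue standing in for the zmq sockets; ToyEngine and the fixed request count are sketch-only simplifications.

import queue
from typing import List


class ToyEngine:
    """Stand-in for LLMEngine: every pending request finishes in one step."""

    def __init__(self) -> None:
        self.pending: List[str] = []

    def add_request(self, request_id: str) -> None:
        self.pending.append(request_id)

    def has_unfinished_requests(self) -> bool:
        return bool(self.pending)

    def step(self) -> List[str]:
        done, self.pending = self.pending, []
        return [f"{rid}: <generated text>" for rid in done]


def run_engine_loop(input_q: queue.Queue, output_q: queue.Queue,
                    num_requests: int) -> None:
    engine = ToyEngine()
    finished = 0
    while finished < num_requests:          # sketch-only termination condition
        # Block until there is work to do (stands in for polling the socket).
        if not engine.has_unfinished_requests():
            engine.add_request(input_q.get())
        # Drain any further input that is already queued, without blocking.
        while not input_q.empty():
            engine.add_request(input_q.get_nowait())
        # Engine step, then stream the results back to the client side.
        outputs = engine.step()
        output_q.put(outputs)
        finished += len(outputs)


input_q: queue.Queue = queue.Queue()
output_q: queue.Queue = queue.Queue()
for rid in ("req-0", "req-1"):
    input_q.put(rid)
run_engine_loop(input_q, output_q, num_requests=2)
print(output_q.get())  # ['req-0: <generated text>', 'req-1: <generated text>']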
+ if (not self.use_async_sockets + or isinstance(request_outputs, RPCGenerateError)): self._send_request_outputs(request_outputs) - def poll_for_work(self): """Poll the socket until there is work to do.""" if not self.engine.has_unfinished_requests(): while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: logger.debug("Waiting for new requests.") - def engine_step(self) -> REQUEST_OUTPUTS_T: """Engine step wrapper with error handling.""" try: - request_outputs = self.engine.step() + return self.engine.step() except Exception as e: self._errored = True - request_outputs = RPCGenerateError(request_id=None, - is_errored=self._errored, - exception=e) - finally: - return request_outputs + return RPCGenerateError(request_id=None, + is_engine_errored=self._errored, + exception=e) def handle_new_input(self): """Handle new input from the socket""" try: while self.input_socket.poll(timeout=0) != 0: # TODO: do we need error handling around the pickling? - message = self.input_socket.recv(copy=False) + message = self.input_socket.recv(copy=False) request = cloudpickle.loads(message.buffer) if isinstance(request, RPCGenerateRequest): @@ -229,8 +224,7 @@ def handle_new_input(self): elif isinstance(request, RPCUtilityRequest): self._handle_utility_request(request) else: - raise ValueError( - "Unknown RPCRequest Type: {request}") + raise ValueError("Unknown RPCRequest Type: {request}") except Exception as e: self._errored = True @@ -259,29 +253,25 @@ def _handle_generate_request(self, request: RPCGenerateRequest): except Exception as err: self.engine.abort_request(request_id) - - # We do not set self._errored = True here, - # since the error is due to an issue adding this - # request to the engine, rather than an issue with - # the engine itself. + + # We do not set self._errored = True here, since the error is + # due to an issue adding this request to the engine, rather + # than an issue with the engine itself. 
e = RPCGenerateError(request_id, self._errored, err) self._send_request_outputs(e) - def _handle_abort_request(self, request: RPCAbortRequest): self.engine.abort_request(request.request_id) if self.log_requests: logger.info("Aborted request %s.", request.request_id) - - def _handle_utility_request(self, request: RPCUtilityRequest): + def _handle_utility_request(self, request: RPCUtilityRequest): if request == RPCUtilityRequest.DO_LOG_STATS: self.engine.do_log_stats() elif request == RPCUtilityRequest.CHECK_HEALTH: self.engine.check_health() self._send_healthy() - def _send_request_outputs(self, request_outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" @@ -290,18 +280,16 @@ def _send_request_outputs(self, request_outputs: REQUEST_OUTPUTS_T): output_bytes = pickle.dumps(request_outputs) self.output_socket.send_multipart((output_bytes, ), copy=False) - def _send_healthy(self): """Send HEALTHY message to RPCClient.""" self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) - def _send_unhealthy(self): """Send UNHEALTHY message to RPCClient.""" self.health_socket.send_multipart(UNHEALTHY_RESPONSE, copy=False) - - def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): """Callback used by engine to make socket handling async with GPU.""" self._send_request_outputs(request_outputs) self.handle_new_input() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3e4faa80be59..04bb7b907426 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -21,8 +21,8 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.multiprocessing.mp_client import MQLLMEngineClient -from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bf8c82dc6804..68cc12039c2c 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple, Type -from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, - LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, +from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoadConfig, + LoRAConfig, ModelConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.executor.ray_utils import initialize_ray_cluster from vllm.lora.request import LoRARequest @@ -150,10 +149,9 @@ async def check_health_async(self) -> None: """Checks if the executor is healthy. 
If not, it should raise an exception.""" self.check_health() - -def get_executor_cls( - engine_config: EngineConfig) -> Type["ExecutorAsyncBase"]: + +def get_executor_cls(engine_config: EngineConfig) -> Type["ExecutorAsyncBase"]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) if isinstance(distributed_executor_backend, type): From 55984949d6c48d00e2dac7776722baae28fc4921 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 13:50:27 +0000 Subject: [PATCH 040/116] added error handling --- examples/openai_completion_client.py | 3 +- vllm/engine/multiprocessing/__init__.py | 2 +- vllm/engine/multiprocessing/client.py | 39 +++--- vllm/engine/multiprocessing/engine.py | 113 ++++++++++++------ vllm/engine/protocol.py | 6 +- vllm/entrypoints/launcher.py | 16 ++- vllm/entrypoints/openai/serving_chat.py | 6 + vllm/entrypoints/openai/serving_completion.py | 6 + 8 files changed, 128 insertions(+), 63 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 32d58ec5317a..9ba249471175 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -8,6 +8,7 @@ # defaults to os.environ.get("OPENAI_API_KEY") api_key=openai_api_key, base_url=openai_api_base, + max_retries=0 ) models = client.models.list() @@ -19,7 +20,7 @@ model=model, prompt="A robot may not injure a human being", stream=stream, - logprobs=3) + max_tokens=1000,) print("Completion results:") if stream: diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index ffcdd2ad2c92..19a0858a23d3 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -32,7 +32,7 @@ class RPCGenerateRequest: @dataclass -class RPCGenerateError(BaseException): +class RPCError: request_id: Optional[str] is_engine_errored: bool exception: BaseException diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index bc2c27c9331e..0a1135120de3 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -16,7 +16,7 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, RPCGenerateError, + RPCAbortRequest, RPCError, RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS @@ -69,6 +69,7 @@ class MQLLMEngineClient: def __init__(self, ipc_path: str): self.context = zmq.asyncio.Context() self._errored = False + self.dead_error = ENGINE_DEAD_ERROR # Send RPCGenerateRequest to the MQLLMEngine. 
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) @@ -128,10 +129,10 @@ async def run_check_health_loop(self, timeout: int): error_message="Health check failed.", socket=self.health_socket) - logger.debug("Health probe complete.") + logger.debug("Health probe successful.") except asyncio.CancelledError: - logger.info("Shutting down MQLLMEngineClient check health loop.") + logger.debug("Shutting down MQLLMEngineClient check health loop.") except Exception as e: logger.exception(repr(e)) @@ -146,21 +147,24 @@ async def run_output_handler_loop(self): request_outputs: REQUEST_OUTPUTS_T = pickle.loads( message.buffer) - if isinstance(request_outputs, BaseException): - if isinstance(request_outputs, RPCGenerateError): - generate_error: RPCGenerateError = request_outputs - request_id = generate_error.request_id - if generate_error.is_engine_errored: + is_error = (isinstance(request_outputs, BaseException) or + isinstance(request_outputs, RPCError)) + + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + if rpc_error.is_engine_errored: self._errored = True - exception = generate_error.exception + exception = rpc_error.exception else: - # MPLLMEngine should always return an RPCGenerateError - # if the error handling is graceful. If we are here, - # we are in a bad state and should shut down the server. + # MPLLMEngine should always return an RPCError + # when an issue arises. If we are here, we are in a + # bad state and should shut down the server. error: BaseException = request_outputs logger.warning( - "Got raw Exception {error} from MPLLMEngine. " - "This should never happen.") + f"Recieved raw Exception {error} rather than " + "RPCError from MPLLMEngine. This should never happen.") self._errored = True request_id = None exception = error @@ -181,13 +185,12 @@ async def run_output_handler_loop(self): queue.put_nowait(request_output) except asyncio.CancelledError: - logger.info("Shutting down MQLLMEngineClient output handler.") + logger.debug("Shutting down MQLLMEngineClient output handler.") async def setup(self): """Setup the client before it starts sending server requests.""" with self.get_data_socket() as socket: - # Wait until server is ready. await self._wait_for_server_rpc(socket) @@ -282,7 +285,7 @@ async def _check_success(self, error_message: str, socket: Socket): response = pickle.loads(frame.buffer) if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, Exception): + if isinstance(response, BaseException): logger.error(error_message) raise response raise ValueError(error_message) @@ -450,8 +453,6 @@ async def generate( finally: # TODO: check if excepted requests are getting here. - # TODO: check if aborted requests are getting here. - self.output_queues.pop(request_id) # Request was canceled by the client. 
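
[Editor's note: the client-side hunks above make the output handler accept either request outputs or an RPCError from the engine, fan them out to per-request queues, and flip an errored flag when the failure is engine-wide. Below is a minimal sketch of that routing pattern against plain asyncio; the ErrorEnvelope and OutputRouter names are illustrative stand-ins, not the actual vLLM classes.]

# Minimal sketch (assumed names, not vLLM code) of routing outputs vs. errors
# to per-request asyncio queues, as the patch above does for MQLLMEngineClient.
import asyncio
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@dataclass
class ErrorEnvelope:
    """Stand-in for RPCError: identifies the failing request and whether the
    whole engine is dead (vs. a single bad request)."""
    request_id: Optional[str]
    is_engine_errored: bool
    exception: BaseException


class OutputRouter:
    """Routes engine messages to per-request queues on the client side."""

    def __init__(self) -> None:
        self.queues: Dict[str, asyncio.Queue] = {}
        self.errored = False

    def add_request(self, request_id: str) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()
        self.queues[request_id] = queue
        return queue

    def handle(self, message: Union[ErrorEnvelope, Tuple[str, str]]) -> None:
        if isinstance(message, ErrorEnvelope):
            if message.is_engine_errored:
                # Engine-level failure: every outstanding request fails.
                self.errored = True
                for queue in self.queues.values():
                    queue.put_nowait(message.exception)
            elif message.request_id in self.queues:
                # Per-request failure: only that caller sees the exception.
                self.queues[message.request_id].put_nowait(message.exception)
        else:
            request_id, output = message
            if request_id in self.queues:
                self.queues[request_id].put_nowait(output)


async def _demo() -> None:
    router = OutputRouter()
    queue = router.add_request("req-0")
    router.handle(("req-0", "partial output"))
    router.handle(ErrorEnvelope("req-0", is_engine_errored=False,
                                exception=ValueError("bad request")))
    print(await queue.get())  # "partial output"
    print(await queue.get())  # the ValueError from the per-request failure


if __name__ == "__main__":
    asyncio.run(_demo())
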
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 12511da16b5e..790f21b94f54 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,6 +1,6 @@ import pickle from contextlib import contextmanager -from typing import Iterator, Union +from typing import List, Iterator, Union import cloudpickle import zmq @@ -13,10 +13,11 @@ IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_FAILED_STR, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCGenerateError, RPCGenerateRequest, + RPCError, RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) from vllm.executor.executor_base import get_executor_cls from vllm.logger import init_logger +from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, @@ -26,7 +27,6 @@ POLLING_TIMEOUT_MS = 10000 HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) -UNHEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_FAILED_STR), ) class MQLLMEngine: @@ -115,6 +115,26 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, log_stats=not engine_args.disable_log_stats, usage_context=usage_context) + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e_core: + try: + logger.exception(repr(e_core)) + if self._errored: + logger.debug("Starting Dead Loop.") + self.run_engine_dead_loop() + except Exception as e_dead_loop: + logger.exception(repr(e_dead_loop)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + self.cleanup() + def cleanup(self): """Cleanup zeromq state on shutdown.""" self.input_socket.close() @@ -134,7 +154,7 @@ def make_data_socket( socket.close(linger=0) def run_startup_loop(self) -> None: - """Loop over startup RPCStatupRequest from RPCClient.""" + """Startup loop for sending data from Engine -> Client.""" with self.make_data_socket() as socket: @@ -173,51 +193,61 @@ def run_startup_loop(self) -> None: copy=False) def run_engine_loop(self): - """Entrypoint for core busy loop of the LLMEngine.""" + """Core busy loop of the LLMEngine.""" while True: # Poll until there is work to do. - self.poll_for_work() + if not self.engine.has_unfinished_requests(): + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for new requests in engine loop.") - # Handle any new data. + # Handle any input from the client. self.handle_new_input() # Engine step. request_outputs = self.engine_step() - # Stream results if needed. - if (not self.use_async_sockets - or isinstance(request_outputs, RPCGenerateError)): + # Send request outputs (if async, done in engine_step callback). 
+ if not self.use_async_sockets: self._send_request_outputs(request_outputs) - def poll_for_work(self): - """Poll the socket until there is work to do.""" - if not self.engine.has_unfinished_requests(): + + def run_engine_dead_loop(self): + """Loop for replying to all requests that we are dead.""" + if not self._errored: + raise ValueError("In dead loop, but found _errored=False") + + while True: + # Poll until there is a request while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for new requests.") + logger.debug("Waiting for new requests in dead loop.") + + # Handle any new data, replying with EngineDeadError + self.handle_new_input() + - def engine_step(self) -> REQUEST_OUTPUTS_T: + def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" + try: return self.engine.step() except Exception as e: self._errored = True - return RPCGenerateError(request_id=None, - is_engine_errored=self._errored, - exception=e) + err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + logger.exception(repr(e)) + self._send_request_outputs(err) + raise e def handle_new_input(self): """Handle new input from the socket""" try: while self.input_socket.poll(timeout=0) != 0: - # TODO: do we need error handling around the pickling? message = self.input_socket.recv(copy=False) request = cloudpickle.loads(message.buffer) if isinstance(request, RPCGenerateRequest): - # Exceptions in RPCGenerateRequest will be caught - # by the handler, meaning any recoverable exceptions - # to only impact that request (and not crash the server) self._handle_generate_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) @@ -230,13 +260,14 @@ def handle_new_input(self): self._errored = True logger.exception(repr(e)) self._send_unhealthy() + raise e def _handle_generate_request(self, request: RPCGenerateRequest): """Handle RPCGenerateRequest by adding it to the LLMEngine.""" request_id = request.request_id if self._errored: - e = RPCGenerateError(request_id, self._errored, ENGINE_DEAD_ERROR) + e = RPCError(request_id, self._errored, ENGINE_DEAD_ERROR) self._send_request_outputs(e) try: @@ -252,14 +283,15 @@ def _handle_generate_request(self, request: RPCGenerateRequest): logger.info("Added request %s.", request.request_id) except Exception as err: - self.engine.abort_request(request_id) - - # We do not set self._errored = True here, since the error is - # due to an issue adding this request to the engine, rather - # than an issue with the engine itself. - e = RPCGenerateError(request_id, self._errored, err) + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. + e = RPCError(request_id, self._errored, err) self._send_request_outputs(e) + # Remove request from the engine. 
+ self.engine.abort_request(request_id) + def _handle_abort_request(self, request: RPCAbortRequest): self.engine.abort_request(request.request_id) if self.log_requests: @@ -267,16 +299,19 @@ def _handle_abort_request(self, request: RPCAbortRequest): def _handle_utility_request(self, request: RPCUtilityRequest): if request == RPCUtilityRequest.DO_LOG_STATS: - self.engine.do_log_stats() + if not self._errored: + self.engine.do_log_stats() elif request == RPCUtilityRequest.CHECK_HEALTH: - self.engine.check_health() - self._send_healthy() + if self._errored: + self._send_unhealthy(ENGINE_DEAD_ERROR) + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._send_unhealthy(e) def _send_request_outputs(self, request_outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" - - # TODO: do we need error handling around the pickling? - output_bytes = pickle.dumps(request_outputs) self.output_socket.send_multipart((output_bytes, ), copy=False) @@ -284,9 +319,10 @@ def _send_healthy(self): """Send HEALTHY message to RPCClient.""" self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) - def _send_unhealthy(self): + def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" - self.health_socket.send_multipart(UNHEALTHY_RESPONSE, copy=False) + error_bytes = pickle.dumps(error) + self.health_socket.send_multipart((error_bytes, ), copy=False) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): @@ -301,5 +337,4 @@ def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, usage_context=usage_context, ipc_path=ipc_path) - engine.run_startup_loop() - engine.run_engine_loop() + engine.start() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index de6314d53219..07cece6b307c 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -15,7 +15,7 @@ @runtime_checkable class AsyncEngineClient(Protocol): - """Protocol class for Clients to AsyncLLMEngine""" + """Protocol class for Clients to Engine""" @property def is_running(self) -> bool: @@ -29,6 +29,10 @@ def is_stopped(self) -> bool: def errored(self) -> bool: ... + @property + def dead_error(self) -> BaseException: + ... + def generate( self, inputs: PromptInputs, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index f4a9c61a431c..935f79f37f0b 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -8,6 +8,7 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.protocol import AsyncEngineClient from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -55,7 +56,7 @@ async def dummy_shutdown() -> None: logger.debug( "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) - logger.info("Gracefully stopping http server") + logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() @@ -82,7 +83,7 @@ async def runtime_error_handler(_, __): return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) @app.exception_handler(AsyncEngineDeadError) - async def engine_dead_handler(_, __): + async def async_engine_dead_handler(_, __): """Kill the server if the async engine is already dead. 
It will not handle any further requests.""" if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: @@ -91,3 +92,14 @@ async def engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("MQLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 78f355228012..70bcc1f167e7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -103,6 +103,12 @@ async def create_chat_completion( if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.async_engine_client.errored: + raise self.async_engine_client.dead_error try: ( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 34f1200753f8..9b8e38ff08a3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -77,6 +77,12 @@ async def create_completion( error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.async_engine_client.errored: + raise self.async_engine_client.dead_error # Return error for unsupported features. 
if request.suffix is not None: From 98aaa7d8bdcbedac846b434d3f906d814719b1ed Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 13:53:23 +0000 Subject: [PATCH 041/116] change name --- tests/entrypoints/openai/test_serving_engine.py | 4 ++-- vllm/engine/multiprocessing/client.py | 2 +- vllm/engine/protocol.py | 2 +- vllm/entrypoints/launcher.py | 6 +++--- vllm/entrypoints/openai/api_server.py | 14 +++++++------- vllm/entrypoints/openai/serving_chat.py | 4 ++-- vllm/entrypoints/openai/serving_completion.py | 4 ++-- vllm/entrypoints/openai/serving_embedding.py | 4 ++-- vllm/entrypoints/openai/serving_engine.py | 4 ++-- vllm/entrypoints/openai/serving_tokenization.py | 4 ++-- 10 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 325bc0343428..6d9e620b4af7 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -4,7 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) @@ -18,7 +18,7 @@ async def _async_serving_engine_init(): - mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_engine_client = MagicMock(spec=EngineClient) mock_model_config = MagicMock(spec=ModelConfig) # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 0a1135120de3..7f602bf76d2b 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -44,7 +44,7 @@ class MPClientClosedError(Exception): class MQLLMEngineClient: """A client wrapper for MQLLMEngine that conforms to the - AsyncEngineClient protocol. + EngineClient protocol. MQLLMEngine and MQLLMEngineClient are intended to run in separate processes communicating via zeromq ipc sockets. 
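
[Editor's note: the rename above is interface-level only — EngineClient remains a typing.Protocol, so the in-process AsyncLLMEngine and the multiprocessing MQLLMEngineClient both satisfy it structurally rather than by inheritance. A minimal sketch of that pattern follows, with a hypothetical FakeClient standing in for a real client; only the errored/dead_error members from the patch are shown.]

# Minimal sketch (assumed FakeClient, not vLLM code) of a runtime-checkable
# Protocol that clients satisfy by shape, mirroring the EngineClient pattern.
from typing import Protocol, runtime_checkable


@runtime_checkable
class EngineClientLike(Protocol):
    """Structural interface: any object with these members conforms,
    whether it wraps an in-process engine or one behind zmq ipc."""

    @property
    def errored(self) -> bool:
        ...

    @property
    def dead_error(self) -> BaseException:
        ...


class FakeClient:
    """Toy implementation used only to show the isinstance() check passing."""

    def __init__(self) -> None:
        self._errored = False

    @property
    def errored(self) -> bool:
        return self._errored

    @property
    def dead_error(self) -> BaseException:
        return RuntimeError("engine is dead; restart the server")


client = FakeClient()
assert isinstance(client, EngineClientLike)  # structural check, no inheritance
if client.errored:
    raise client.dead_error
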
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 07cece6b307c..70444faa670a 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -14,7 +14,7 @@ @runtime_checkable -class AsyncEngineClient(Protocol): +class EngineClient(Protocol): """Protocol class for Clients to Engine""" @property diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 935f79f37f0b..72a73c7a3c0e 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -9,14 +9,14 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, engine: AsyncEngineClient, +async def serve_http(app: FastAPI, engine: EngineClient, **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: @@ -61,7 +61,7 @@ async def dummy_shutdown() -> None: def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server, - engine: AsyncEngineClient) -> None: + engine: EngineClient) -> None: """Adds handlers for fatal errors that should crash the server""" @app.exception_handler(RuntimeError) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 04bb7b907426..47f8b260ddd4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -23,7 +23,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import run_mp_engine -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -54,7 +54,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds -async_engine_client: AsyncEngineClient +async_engine_client: EngineClient engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion @@ -97,7 +97,7 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + args: Namespace) -> AsyncIterator[Optional[EngineClient]]: # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit @@ -118,9 +118,9 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[AsyncEngineClient]]: +) -> AsyncIterator[Optional[EngineClient]]: """ - Create AsyncEngineClient, either: + Create EngineClient, either: - in-process using the AsyncLLMEngine Directly - multiprocess using AsyncLLMEngine RPC @@ -162,7 +162,7 @@ async def build_async_engine_client_from_engine_args( logger.info("Multiprocessing frontend to use %s for IPC Path.", ipc_path) - # Build RPCClient, which conforms to AsyncEngineClient Protocol. + # Build RPCClient, which conforms to EngineClient Protocol. # NOTE: Actually, this is not true yet. 
We still need to support # embedding models via RPC (see TODO above) mp_engine_client = MQLLMEngineClient(ipc_path) @@ -429,7 +429,7 @@ async def authentication(request: Request, call_next): async def init_app( - async_engine_client: AsyncEngineClient, + async_engine_client: EngineClient, args: Namespace, ) -> FastAPI: app = build_app(args) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 70bcc1f167e7..d5a9cee33031 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,7 +9,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_chat_template, load_chat_template, @@ -44,7 +44,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, - async_engine_client: AsyncEngineClient, + async_engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9b8e38ff08a3..e980c7f9367f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + async_engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 12ec6be03cd6..9df0e97534a7 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -8,7 +8,7 @@ from typing_extensions import assert_never from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -71,7 +71,7 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + async_engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ac74527441cd..5fa9cd0ec9f9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -64,7 +64,7 @@ class OpenAIServing: def __init__( self, - async_engine_client: AsyncEngineClient, + async_engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 69a5ad5b62cf..727e05c419ce 100644 --- 
a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (apply_chat_template, load_chat_template, parse_chat_messages_futures) @@ -27,7 +27,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + async_engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, From ba5ef3892beba4c09d88c47ea43706fb986cada8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 13:54:04 +0000 Subject: [PATCH 042/116] change name --- vllm/entrypoints/openai/api_server.py | 36 +++++++++---------- vllm/entrypoints/openai/serving_chat.py | 14 ++++---- vllm/entrypoints/openai/serving_completion.py | 14 ++++---- vllm/entrypoints/openai/serving_embedding.py | 8 ++--- vllm/entrypoints/openai/serving_engine.py | 6 ++-- .../openai/serving_tokenization.py | 8 ++--- 6 files changed, 43 insertions(+), 43 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 47f8b260ddd4..42630d6060b3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -54,7 +54,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds -async_engine_client: EngineClient +engine_client: EngineClient engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion @@ -85,7 +85,7 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await async_engine_client.do_log_stats() + await engine_client.do_log_stats() if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) @@ -99,18 +99,18 @@ async def _force_log(): async def build_async_engine_client( args: Namespace) -> AsyncIterator[Optional[EngineClient]]: - # Context manager to handle async_engine_client lifecycle + # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) # Backend itself still global for the silly lil' health handler - global async_engine_client + global engine_client async with build_async_engine_client_from_engine_args( engine_args, args.disable_frontend_multiprocessing) as engine: - async_engine_client = engine # type: ignore[assignment] + engine_client = engine # type: ignore[assignment] yield engine @@ -242,7 +242,7 @@ def mount_metrics(app: FastAPI): @router.get("/health") async def health() -> Response: """Health check.""" - await async_engine_client.check_health() + await engine_client.check_health() return Response(status_code=200) @@ -333,14 +333,14 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): @router.post("/start_profile") async def start_profile(): logger.info("Starting profiler...") - await async_engine_client.start_profile() + await engine_client.start_profile() logger.info("Profiler started.") return Response(status_code=200) @router.post("/stop_profile") async def stop_profile(): logger.info("Stopping profiler...") - await async_engine_client.stop_profile() + await engine_client.stop_profile() logger.info("Profiler stopped.") return Response(status_code=200) @@ -429,7 +429,7 @@ async def authentication(request: Request, call_next): 
async def init_app( - async_engine_client: EngineClient, + engine_client: EngineClient, args: Namespace, ) -> FastAPI: app = build_app(args) @@ -439,7 +439,7 @@ async def init_app( else: served_model_names = [args.model] - model_config = await async_engine_client.get_model_config() + model_config = await engine_client.get_model_config() if args.disable_log_requests: request_logger = None @@ -452,7 +452,7 @@ async def init_app( global openai_serving_tokenization openai_serving_chat = OpenAIServingChat( - async_engine_client, + engine_client, model_config, served_model_names, args.response_role, @@ -464,7 +464,7 @@ async def init_app( enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser) openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -473,13 +473,13 @@ async def init_app( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_embedding = OpenAIServingEmbedding( - async_engine_client, + engine_client, model_config, served_model_names, request_logger=request_logger, ) openai_serving_tokenization = OpenAIServingTokenization( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -495,16 +495,16 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - async with build_async_engine_client(args) as async_engine_client: + async with build_async_engine_client(args) as engine_client: # If None, creation of the client failed and we exit. - if async_engine_client is None: + if engine_client is None: return - app = await init_app(async_engine_client, args) + app = await init_app(engine_client, args) shutdown_task = await serve_http( app, - engine=async_engine_client, + engine=engine_client, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d5a9cee33031..9b95c77c27a8 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -44,7 +44,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, - async_engine_client: EngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -56,7 +56,7 @@ def __init__(self, return_tokens_as_token_ids: bool = False, enable_auto_tools: bool = False, tool_parser: Optional[str] = None): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -107,8 +107,8 @@ async def create_chat_completion( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). 
- if self.async_engine_client.errored: - raise self.async_engine_client.dead_error + if self.engine_client.errored: + raise self.engine_client.dead_error try: ( @@ -117,7 +117,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.async_engine_client.get_tokenizer( + tokenizer = await self.engine_client.get_tokenizer( lora_request) conversation, mm_data_future = parse_chat_messages_futures( @@ -200,7 +200,7 @@ async def create_chat_completion( engine_inputs["multi_modal_data"] = mm_data is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + await self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -208,7 +208,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.async_engine_client.generate( + result_generator = self.engine_client.generate( engine_inputs, sampling_params, request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e980c7f9367f..38a3f6a584c7 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - async_engine_client: EngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -52,7 +52,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -81,8 +81,8 @@ async def create_completion( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.async_engine_client.errored: - raise self.async_engine_client.dead_error + if self.engine_client.errored: + raise self.engine_client.dead_error # Return error for unsupported features. 
if request.suffix is not None: @@ -101,7 +101,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( + tokenizer = await self.engine_client.get_tokenizer( lora_request) guided_decode_logits_processor = ( @@ -131,7 +131,7 @@ async def create_completion( prompt_adapter_request=prompt_adapter_request) is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + await self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -139,7 +139,7 @@ async def create_completion( raw_request.headers): log_tracing_disabled_warning() - generator = self.async_engine_client.generate( + generator = self.engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 9df0e97534a7..6cb16665d4c0 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - async_engine_client: EngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -118,7 +118,7 @@ async def create_embedding( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( + tokenizer = await self.engine_client.get_tokenizer( lora_request) pooling_params = request.to_pooling_params() @@ -144,7 +144,7 @@ async def create_embedding( "Prompt adapter is not supported " "for embedding models") - generator = self.async_engine_client.encode( + generator = self.engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5fa9cd0ec9f9..72f9381abc7d 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -64,7 +64,7 @@ class OpenAIServing: def __init__( self, - async_engine_client: EngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -75,7 +75,7 @@ def __init__( ): super().__init__() - self.async_engine_client = async_engine_client + self.engine_client = engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -159,7 +159,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.async_engine_client.get_decoding_config() + decoding_config = await self.engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 727e05c419ce..ba1a0237d924 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ 
b/vllm/entrypoints/openai/serving_tokenization.py @@ -27,7 +27,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - async_engine_client: EngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -35,7 +35,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -64,7 +64,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -121,7 +121,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, From 18b5a949819408630975cb6ee89da527fa44bbb3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 14:00:15 +0000 Subject: [PATCH 043/116] added dead_error to asyncengine --- vllm/engine/async_llm_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 78cdfa1dedf7..2592921a944b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -698,6 +698,14 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return self._errored_with is not None + @property + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + def set_errored(self, exc: Exception) -> None: self._errored_with = exc From b048961b3d9b1364118b7dd559a2ba484f2d81cf Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 14:21:02 +0000 Subject: [PATCH 044/116] moved tests under openai --- .../entrypoints/openai/rpc/test_zmq_client.py | 224 +++++++++--------- .../openai}/test_chat_template.py | 0 .../openai}/test_openapi_server_ray.py | 0 vllm/engine/multiprocessing/client.py | 6 +- 4 files changed, 115 insertions(+), 115 deletions(-) rename tests/{async_engine => entrypoints/openai}/test_chat_template.py (100%) rename tests/{async_engine => entrypoints/openai}/test_openapi_server_ray.py (100%) diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py index cafd125c5a59..5bac46f4d1f0 100644 --- a/tests/entrypoints/openai/rpc/test_zmq_client.py +++ b/tests/entrypoints/openai/rpc/test_zmq_client.py @@ -1,120 +1,120 @@ -import asyncio -import tempfile -import unittest -import unittest.mock -import uuid +# import asyncio +# import tempfile +# import unittest +# import unittest.mock +# import uuid -import pytest -import pytest_asyncio +# import pytest +# import pytest_asyncio -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, - RPCClientClosedError) -from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer +# from vllm.engine.async_llm_engine import AsyncLLMEngine +# from vllm.engine.multiprocessing.client import (MQLLMEngineClient, +# MQClientClosedError) +# from vllm.engine.multiprocessing.engine import MQLLMEngine -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest_asyncio.fixture(scope="function") -async def dummy_server(tmp_socket, monkeypatch): - dummy_engine = unittest.mock.AsyncMock() - - def dummy_engine_builder(*args, **kwargs): - return dummy_engine - - with monkeypatch.context() as m: - m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) - server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) - - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - try: - yield server - finally: - server_task.cancel() - server.cleanup() - - -@pytest_asyncio.fixture(scope="function") -async def client(tmp_socket): - client = AsyncEngineRPCClient(rpc_path=tmp_socket) - # Sanity check: the server is connected - await client._wait_for_server_rpc() - - try: - yield client - finally: - client.close() - - -@pytest.mark.asyncio -async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server _not_ reply with a model config - m.setattr(dummy_server, "get_config", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # And ensure the task completes anyway - # (client.setup() invokes server.get_config()) - client_task = asyncio.get_running_loop().create_task(client.setup()) - with pytest.raises(TimeoutError, match="Server didn't reply within"): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Hang all abort 
requests - m.setattr(dummy_server, "abort", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # The client should suppress timeouts on `abort`s - # and return normally, assuming the server will eventually - # abort the request. - client_task = asyncio.get_running_loop().create_task( - client.abort("test request id")) - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_data_methods_reraise_exceptions( - monkeypatch, dummy_server, client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server raise some random exception - exception = RuntimeError("Client test exception") - - def raiser(): - raise exception - - m.setattr(dummy_server.engine, "get_model_config", raiser) - m.setattr(client, "_data_timeout", 10) - - client_task = asyncio.get_running_loop().create_task(client.setup()) - # And ensure the task completes, raising the exception - with pytest.raises(RuntimeError, match=str(exception)): - await asyncio.wait_for(client_task, timeout=0.05) +# @pytest.fixture(scope="function") +# def tmp_socket(): +# with tempfile.TemporaryDirectory() as td: +# yield f"ipc://{td}/{uuid.uuid4()}" + + +# @pytest_asyncio.fixture(scope="function") +# async def dummy_server(tmp_socket, monkeypatch): +# dummy_engine = unittest.mock.AsyncMock() + +# def dummy_engine_builder(*args, **kwargs): +# return dummy_engine + +# with monkeypatch.context() as m: +# m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) +# server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) + +# loop = asyncio.get_running_loop() +# server_task = loop.create_task(server.run_server_loop()) + +# try: +# yield server +# finally: +# server_task.cancel() +# server.cleanup() + + +# @pytest_asyncio.fixture(scope="function") +# async def client(tmp_socket): +# client = AsyncEngineRPCClient(rpc_path=tmp_socket) +# # Sanity check: the server is connected +# await client._wait_for_server_rpc() + +# try: +# yield client +# finally: +# client.close() + + +# @pytest.mark.asyncio +# async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, +# client: AsyncEngineRPCClient): +# with monkeypatch.context() as m: +# # Make the server _not_ reply with a model config +# m.setattr(dummy_server, "get_config", lambda x: None) +# m.setattr(client, "_data_timeout", 10) + +# # And ensure the task completes anyway +# # (client.setup() invokes server.get_config()) +# client_task = asyncio.get_running_loop().create_task(client.setup()) +# with pytest.raises(TimeoutError, match="Server didn't reply within"): +# await asyncio.wait_for(client_task, timeout=0.05) + + +# @pytest.mark.asyncio +# async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, +# client: AsyncEngineRPCClient): +# with monkeypatch.context() as m: +# # Hang all abort requests +# m.setattr(dummy_server, "abort", lambda x: None) +# m.setattr(client, "_data_timeout", 10) + +# # The client should suppress timeouts on `abort`s +# # and return normally, assuming the server will eventually +# # abort the request. 
+# client_task = asyncio.get_running_loop().create_task( +# client.abort("test request id")) +# await asyncio.wait_for(client_task, timeout=0.05) + + +# @pytest.mark.asyncio +# async def test_client_data_methods_reraise_exceptions( +# monkeypatch, dummy_server, client: AsyncEngineRPCClient): +# with monkeypatch.context() as m: +# # Make the server raise some random exception +# exception = RuntimeError("Client test exception") + +# def raiser(): +# raise exception + +# m.setattr(dummy_server.engine, "get_model_config", raiser) +# m.setattr(client, "_data_timeout", 10) + +# client_task = asyncio.get_running_loop().create_task(client.setup()) +# # And ensure the task completes, raising the exception +# with pytest.raises(RuntimeError, match=str(exception)): +# await asyncio.wait_for(client_task, timeout=0.05) -@pytest.mark.asyncio -async def test_client_errors_after_closing(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): +# @pytest.mark.asyncio +# async def test_client_errors_after_closing(monkeypatch, dummy_server, +# client: AsyncEngineRPCClient): - client.close() +# client.close() - # Healthchecks and generate requests will fail with explicit errors - with pytest.raises(RPCClientClosedError): - await client.check_health() - with pytest.raises(RPCClientClosedError): - async for _ in client.generate(None, None, None): - pass - - # But no-ops like aborting will pass - await client.abort("test-request-id") - await client.do_log_stats() +# # Healthchecks and generate requests will fail with explicit errors +# with pytest.raises(RPCClientClosedError): +# await client.check_health() +# with pytest.raises(RPCClientClosedError): +# async for _ in client.generate(None, None, None): +# pass + +# # But no-ops like aborting will pass +# await client.abort("test-request-id") +# await client.do_log_stats() diff --git a/tests/async_engine/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py similarity index 100% rename from tests/async_engine/test_chat_template.py rename to tests/entrypoints/openai/test_chat_template.py diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/entrypoints/openai/test_openapi_server_ray.py similarity index 100% rename from tests/async_engine/test_openapi_server_ray.py rename to tests/entrypoints/openai/test_openapi_server_ray.py diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7f602bf76d2b..8e58729825c1 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -31,7 +31,7 @@ logger = init_logger(__name__) -class MPClientClosedError(Exception): +class MQClientClosedError(Exception): """Exception class raised when the client is used post-close. The client can be closed, which closes the ZMQ context. 
This normally @@ -375,13 +375,13 @@ async def _is_tracing_enabled_rpc(self, socket: Socket) -> bool: async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" - with suppress(MPClientClosedError): + with suppress(MQClientClosedError): await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), socket=self.input_socket) async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(MPClientClosedError): + with suppress(MQClientClosedError): await self._send_one_way_rpc_request( request=RPCUtilityRequest.DO_LOG_STATS, socket=self.input_socket) From 6b2e18b89a2442506a94c1964e709efa375cd8c0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 14:23:51 +0000 Subject: [PATCH 045/116] updated tests --- tests/entrypoints/openai/rpc/__init__.py | 0 .../entrypoints/openai/rpc/test_zmq_client.py | 120 ------------------ .../entrypoints/openai/test_chat_template.py | 2 +- .../openai/test_openapi_server_ray.py | 2 +- 4 files changed, 2 insertions(+), 122 deletions(-) delete mode 100644 tests/entrypoints/openai/rpc/__init__.py delete mode 100644 tests/entrypoints/openai/rpc/test_zmq_client.py diff --git a/tests/entrypoints/openai/rpc/__init__.py b/tests/entrypoints/openai/rpc/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py deleted file mode 100644 index 5bac46f4d1f0..000000000000 --- a/tests/entrypoints/openai/rpc/test_zmq_client.py +++ /dev/null @@ -1,120 +0,0 @@ -# import asyncio -# import tempfile -# import unittest -# import unittest.mock -# import uuid - -# import pytest -# import pytest_asyncio - -# from vllm.engine.async_llm_engine import AsyncLLMEngine -# from vllm.engine.multiprocessing.client import (MQLLMEngineClient, -# MQClientClosedError) -# from vllm.engine.multiprocessing.engine import MQLLMEngine - - -# @pytest.fixture(scope="function") -# def tmp_socket(): -# with tempfile.TemporaryDirectory() as td: -# yield f"ipc://{td}/{uuid.uuid4()}" - - -# @pytest_asyncio.fixture(scope="function") -# async def dummy_server(tmp_socket, monkeypatch): -# dummy_engine = unittest.mock.AsyncMock() - -# def dummy_engine_builder(*args, **kwargs): -# return dummy_engine - -# with monkeypatch.context() as m: -# m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) -# server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) - -# loop = asyncio.get_running_loop() -# server_task = loop.create_task(server.run_server_loop()) - -# try: -# yield server -# finally: -# server_task.cancel() -# server.cleanup() - - -# @pytest_asyncio.fixture(scope="function") -# async def client(tmp_socket): -# client = AsyncEngineRPCClient(rpc_path=tmp_socket) -# # Sanity check: the server is connected -# await client._wait_for_server_rpc() - -# try: -# yield client -# finally: -# client.close() - - -# @pytest.mark.asyncio -# async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, -# client: AsyncEngineRPCClient): -# with monkeypatch.context() as m: -# # Make the server _not_ reply with a model config -# m.setattr(dummy_server, "get_config", lambda x: None) -# m.setattr(client, "_data_timeout", 10) - -# # And ensure the task completes anyway -# # (client.setup() invokes server.get_config()) -# client_task = asyncio.get_running_loop().create_task(client.setup()) -# with pytest.raises(TimeoutError, match="Server didn't reply within"): -# await 
asyncio.wait_for(client_task, timeout=0.05) - - -# @pytest.mark.asyncio -# async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, -# client: AsyncEngineRPCClient): -# with monkeypatch.context() as m: -# # Hang all abort requests -# m.setattr(dummy_server, "abort", lambda x: None) -# m.setattr(client, "_data_timeout", 10) - -# # The client should suppress timeouts on `abort`s -# # and return normally, assuming the server will eventually -# # abort the request. -# client_task = asyncio.get_running_loop().create_task( -# client.abort("test request id")) -# await asyncio.wait_for(client_task, timeout=0.05) - - -# @pytest.mark.asyncio -# async def test_client_data_methods_reraise_exceptions( -# monkeypatch, dummy_server, client: AsyncEngineRPCClient): -# with monkeypatch.context() as m: -# # Make the server raise some random exception -# exception = RuntimeError("Client test exception") - -# def raiser(): -# raise exception - -# m.setattr(dummy_server.engine, "get_model_config", raiser) -# m.setattr(client, "_data_timeout", 10) - -# client_task = asyncio.get_running_loop().create_task(client.setup()) -# # And ensure the task completes, raising the exception -# with pytest.raises(RuntimeError, match=str(exception)): -# await asyncio.wait_for(client_task, timeout=0.05) - - -# @pytest.mark.asyncio -# async def test_client_errors_after_closing(monkeypatch, dummy_server, -# client: AsyncEngineRPCClient): - -# client.close() - -# # Healthchecks and generate requests will fail with explicit errors -# with pytest.raises(RPCClientClosedError): -# await client.check_health() -# with pytest.raises(RPCClientClosedError): -# async for _ in client.generate(None, None, None): -# pass - -# # But no-ops like aborting will pass -# await client.abort("test-request-id") -# await client.do_log_stats() diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 4df6c0297328..6e33bb2e5a72 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -4,7 +4,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer -from ..utils import VLLM_PATH +from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" assert chatml_jinja_path.exists() diff --git a/tests/entrypoints/openai/test_openapi_server_ray.py b/tests/entrypoints/openai/test_openapi_server_ray.py index f70118546c7b..1c5f645be121 100644 --- a/tests/entrypoints/openai/test_openapi_server_ray.py +++ b/tests/entrypoints/openai/test_openapi_server_ray.py @@ -2,7 +2,7 @@ import pytest import pytest_asyncio -from ..utils import VLLM_PATH, RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" From 7a7ff5b1178487e5198371f743d2075f72cbe414 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 14:35:30 +0000 Subject: [PATCH 046/116] revert executor change --- vllm/engine/async_llm_engine.py | 68 ++++++++++++++++++++++++++++++++- vllm/executor/executor_base.py | 62 ------------------------------ 2 files changed, 66 insertions(+), 64 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 2592921a944b..df38a86948ee 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,7 @@ from typing_extensions import assert_never import vllm.envs as envs 
-from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs @@ -15,7 +15,8 @@ from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, PromptComponents, SchedulerOutputState) from vllm.engine.metrics_types import StatLoggerBase -from vllm.executor.executor_base import get_executor_cls +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import ray from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, SingletonPromptInputs) @@ -649,6 +650,69 @@ def __init__(self, # Lazy initialized fields self._request_tracker: RequestTracker + + @classmethod + def _get_executor_cls( + cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorAsyncBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorAsyncBase. Got {distributed_executor_backend}.") + if distributed_executor_backend.uses_ray: # type: ignore + initialize_ray_cluster(engine_config.parallel_config) + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutorAsync + executor_class = NeuronExecutorAsync + elif engine_config.device_config.device_type == "tpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync + executor_class = RayTPUExecutorAsync + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutorAsync + executor_class = TPUExecutorAsync + elif engine_config.device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutorAsync + executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "openvino": + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with " + "the OpenVINO backend.") + from vllm.executor.openvino_executor import OpenVINOExecutorAsync + executor_class = OpenVINOExecutorAsync + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend is None: + from vllm.executor.xpu_executor import XPUExecutorAsync + executor_class = XPUExecutorAsync + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync + executor_class = RayXPUExecutorAsync + elif distributed_executor_backend == "mp": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.multiproc_xpu_executor import ( + MultiprocessingXPUExecutorAsync) + executor_class = MultiprocessingXPUExecutorAsync + else: + raise RuntimeError( + "Not supported distributed execution model on XPU device.") + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync + executor_class = RayGPUExecutorAsync + elif distributed_executor_backend == "mp": + from 
vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutorAsync) + executor_class = MultiprocessingGPUExecutorAsync + else: + from vllm.executor.gpu_executor import GPUExecutorAsync + executor_class = GPUExecutorAsync + return executor_class @classmethod def from_engine_args( diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 68cc12039c2c..01fd3bb6279b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -149,65 +149,3 @@ async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" self.check_health() - - -def get_executor_cls(engine_config: EngineConfig) -> Type["ExecutorAsyncBase"]: - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend) - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorAsyncBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorAsyncBase. Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) - executor_class = distributed_executor_backend - elif engine_config.device_config.device_type == "neuron": - from vllm.executor.neuron_executor import NeuronExecutorAsync - executor_class = NeuronExecutorAsync - elif engine_config.device_config.device_type == "tpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync - executor_class = RayTPUExecutorAsync - else: - assert distributed_executor_backend is None - from vllm.executor.tpu_executor import TPUExecutorAsync - executor_class = TPUExecutorAsync - elif engine_config.device_config.device_type == "cpu": - from vllm.executor.cpu_executor import CPUExecutorAsync - executor_class = CPUExecutorAsync - elif engine_config.device_config.device_type == "openvino": - assert distributed_executor_backend is None, ( - "Distributed execution is not supported with " - "the OpenVINO backend.") - from vllm.executor.openvino_executor import OpenVINOExecutorAsync - executor_class = OpenVINOExecutorAsync - elif engine_config.device_config.device_type == "xpu": - if distributed_executor_backend is None: - from vllm.executor.xpu_executor import XPUExecutorAsync - executor_class = XPUExecutorAsync - elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync - executor_class = RayXPUExecutorAsync - elif distributed_executor_backend == "mp": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.multiproc_xpu_executor import ( - MultiprocessingXPUExecutorAsync) - executor_class = MultiprocessingXPUExecutorAsync - else: - raise RuntimeError( - "Not supported distributed execution model on XPU device.") - elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync - executor_class = RayGPUExecutorAsync - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutorAsync) - executor_class = MultiprocessingGPUExecutorAsync - else: - from vllm.executor.gpu_executor import GPUExecutorAsync - executor_class = GPUExecutorAsync - return executor_class From 
b7e1fe995a936781b10fe4c958a7063686f90d4d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 14:36:28 +0000 Subject: [PATCH 047/116] revert --- vllm/engine/async_llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index df38a86948ee..a207587cc20d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -650,7 +650,7 @@ def __init__(self, # Lazy initialized fields self._request_tracker: RequestTracker - + @classmethod def _get_executor_cls( cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: @@ -730,7 +730,7 @@ def from_engine_args( from vllm.executor import ray_utils ray_utils.assert_ray_available() - executor_class = get_executor_cls(engine_config) + executor_class = cls._get_executor_cls(engine_config) # Create the async LLM engine. engine = cls( From 48068d596559ea4ae2a63ef731136b80e65bdaf0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 14:37:25 +0000 Subject: [PATCH 048/116] executor class --- vllm/engine/multiprocessing/engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 790f21b94f54..87dc1ea863be 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -15,7 +15,6 @@ VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) -from vllm.executor.executor_base import get_executor_cls from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -104,7 +103,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, "Launch with --disable-frontend-multiprocessing if you " "need to deploy with this flag (not recommended).") - executor_class = get_executor_cls(engine_config) + executor_class = LLMEngine._get_executor_cls(engine_config) return cls( ipc_path=ipc_path, From e3daa28d2afa2c7ad906b0dc2ed4b52671523013 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 14:52:29 +0000 Subject: [PATCH 049/116] cleanup format --- examples/openai_completion_client.py | 6 +-- vllm/engine/async_llm_engine.py | 1 - vllm/engine/multiprocessing/__init__.py | 3 +- vllm/engine/multiprocessing/client.py | 25 +++++------ vllm/engine/multiprocessing/engine.py | 45 ++++++++++--------- vllm/entrypoints/launcher.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 11 +++-- vllm/entrypoints/openai/serving_completion.py | 11 +++-- vllm/entrypoints/openai/serving_embedding.py | 3 +- vllm/executor/executor_base.py | 9 ++-- 10 files changed, 54 insertions(+), 62 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 9ba249471175..5804345603a4 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -8,8 +8,7 @@ # defaults to os.environ.get("OPENAI_API_KEY") api_key=openai_api_key, base_url=openai_api_base, - max_retries=0 -) + max_retries=0) models = client.models.list() model = models.data[0].id @@ -20,7 +19,8 @@ model=model, prompt="A robot may not injure a human being", stream=stream, - max_tokens=1000,) + max_tokens=1000, +) print("Completion results:") if stream: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a207587cc20d..644c6fb96c29 100644 --- a/vllm/engine/async_llm_engine.py +++ 
b/vllm/engine/async_llm_engine.py @@ -17,7 +17,6 @@ from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster, ray -from vllm.executor.ray_utils import ray from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, SingletonPromptInputs) from vllm.inputs.parse import is_explicit_encoder_decoder_prompt diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 19a0858a23d3..9835abd63237 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -9,7 +9,6 @@ from vllm.sampling_params import SamplingParams VLLM_RPC_SUCCESS_STR = "SUCCESS" -VLLM_RPC_FAILED_STR = "FAILED" IPC_INPUT_EXT = "_input_socket" IPC_OUTPUT_EXT = "_output_socket" @@ -62,7 +61,7 @@ class RPCStartupRequest(Enum): RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCUtilityRequest, RPCStartupRequest] -REQUEST_OUTPUTS_T = Union[List[RequestOutput], BaseException] +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] ENGINE_DEAD_ERROR = MQEngineDeadError( "Engine loop is not running. Inspect the output to find " diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 8e58729825c1..69f20afc08ae 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -14,11 +14,10 @@ ParallelConfig, SchedulerConfig) from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, - RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, RPCError, - RPCGenerateRequest, RPCStartupRequest, - RPCUtilityRequest) + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCStartupRequest, RPCUtilityRequest) from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger @@ -144,12 +143,10 @@ async def run_output_handler_loop(self): try: while True: message: Frame = await self.output_socket.recv(copy=False) - request_outputs: REQUEST_OUTPUTS_T = pickle.loads( - message.buffer) + request_outputs = pickle.loads(message.buffer) - is_error = (isinstance(request_outputs, BaseException) or - isinstance(request_outputs, RPCError)) - + is_error = isinstance(request_outputs, + (BaseException, RPCError)) if is_error: if isinstance(request_outputs, RPCError): rpc_error: RPCError = request_outputs @@ -159,12 +156,12 @@ async def run_output_handler_loop(self): exception = rpc_error.exception else: # MPLLMEngine should always return an RPCError - # when an issue arises. If we are here, we are in a + # when an issue arises. If we are here, we are in a # bad state and should shut down the server. error: BaseException = request_outputs - logger.warning( - f"Recieved raw Exception {error} rather than " - "RPCError from MPLLMEngine. This should never happen.") + logger.error( + "Received raw Exception %s rather than RPCError " + "from MPLLMEngine. 
This should never happen.", error) self._errored = True request_id = None exception = error diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 87dc1ea863be..ee374c394d07 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,6 +1,6 @@ import pickle from contextlib import contextmanager -from typing import List, Iterator, Union +from typing import Iterator, List, Union import cloudpickle import zmq @@ -11,7 +11,6 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, - VLLM_RPC_FAILED_STR, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) @@ -208,23 +207,21 @@ def run_engine_loop(self): # Send request outputs (if async, done in engine_step callback). if not self.use_async_sockets: - self._send_request_outputs(request_outputs) - + self._send_outputs(request_outputs) def run_engine_dead_loop(self): """Loop for replying to all requests that we are dead.""" if not self._errored: raise ValueError("In dead loop, but found _errored=False") - + while True: # Poll until there is a request while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: logger.debug("Waiting for new requests in dead loop.") - + # Handle any new data, replying with EngineDeadError self.handle_new_input() - def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" @@ -232,11 +229,11 @@ def engine_step(self) -> List[RequestOutput]: return self.engine.step() except Exception as e: self._errored = True - err = RPCError(request_id=None, - is_engine_errored=True, - exception=e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) logger.exception(repr(e)) - self._send_request_outputs(err) + self._send_outputs(rpc_err) raise e def handle_new_input(self): @@ -258,7 +255,7 @@ def handle_new_input(self): except Exception as e: self._errored = True logger.exception(repr(e)) - self._send_unhealthy() + self._send_unhealthy(e) raise e def _handle_generate_request(self, request: RPCGenerateRequest): @@ -266,8 +263,10 @@ def _handle_generate_request(self, request: RPCGenerateRequest): request_id = request.request_id if self._errored: - e = RPCError(request_id, self._errored, ENGINE_DEAD_ERROR) - self._send_request_outputs(e) + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR) + self._send_outputs(rpc_err) try: self.engine.add_request( @@ -281,12 +280,14 @@ def _handle_generate_request(self, request: RPCGenerateRequest): if self.log_requests: logger.info("Added request %s.", request.request_id) - except Exception as err: - # We do not set self._errored = True here, since the error - # is due to an issue adding this request to the engine, + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, # rather than an issue with the engine itself. - e = RPCError(request_id, self._errored, err) - self._send_request_outputs(e) + rpc_err = RPCError(request_id=request_id, + is_engine_errored=self._errored, + exception=e) + self._send_outputs(rpc_err) # Remove request from the engine. 
self.engine.abort_request(request_id) @@ -309,9 +310,9 @@ def _handle_utility_request(self, request: RPCUtilityRequest): except Exception as e: self._send_unhealthy(e) - def _send_request_outputs(self, request_outputs: REQUEST_OUTPUTS_T): + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" - output_bytes = pickle.dumps(request_outputs) + output_bytes = pickle.dumps(outputs) self.output_socket.send_multipart((output_bytes, ), copy=False) def _send_healthy(self): @@ -326,7 +327,7 @@ def _send_unhealthy(self, error: BaseException): def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): """Callback used by engine to make socket handling async with GPU.""" - self._send_request_outputs(request_outputs) + self._send_outputs(request_outputs) self.handle_new_input() diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 72a73c7a3c0e..017ce4e013f0 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -92,7 +92,7 @@ async def async_engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - + @app.exception_handler(MQEngineDeadError) async def mq_engine_dead_handler(_, __): """Kill the server if the mq engine is already dead. It will diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9b95c77c27a8..77b6a4c917e7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -103,9 +103,9 @@ async def create_chat_completion( if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) return error_check_ret - + # If the engine is dead, raise the engine's DEAD_ERROR. - # This is required for the streaming case, where we return a + # This is required for the streaming case, where we return a # success status before we actually start generating text :). if self.engine_client.errored: raise self.engine_client.dead_error @@ -117,8 +117,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) @@ -199,8 +198,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = ( - await self.engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 38a3f6a584c7..3f36867a2960 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -77,9 +77,9 @@ async def create_completion( error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret - + # If the engine is dead, raise the engine's DEAD_ERROR. - # This is required for the streaming case, where we return a + # This is required for the streaming case, where we return a # success status before we actually start generating text :). 
if self.engine_client.errored: raise self.engine_client.dead_error @@ -101,8 +101,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -130,8 +129,8 @@ async def create_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = ( - await self.engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 6cb16665d4c0..f111a3a8277b 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -118,8 +118,7 @@ async def create_embedding( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 01fd3bb6279b..c96cb0f2c298 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Set, Tuple, Type +from typing import List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoadConfig, - LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PromptAdapterConfig, SchedulerConfig, +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) -from vllm.executor.ray_utils import initialize_ray_cluster from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest From 7880b7592a59d6cb0206244e6eb420c0538c12f3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 15:05:10 +0000 Subject: [PATCH 050/116] format --- vllm/engine/multiprocessing/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 69f20afc08ae..8a6b5fed7149 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -160,8 +160,8 @@ async def run_output_handler_loop(self): # bad state and should shut down the server. error: BaseException = request_outputs logger.error( - "Received raw Exception %s rather than RPCError " - "from MPLLMEngine. This should never happen.", error) + "Received raw Exception %s rather than RPCError from " + "MPLLMEngine. 
This should never happen.", error) self._errored = True request_id = None exception = error From 29fe3c8561dd1e28000f0a10fb368a64b0714899 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 15:05:49 +0000 Subject: [PATCH 051/116] shorten --- vllm/engine/multiprocessing/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 8a6b5fed7149..497376a24205 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -160,7 +160,7 @@ async def run_output_handler_loop(self): # bad state and should shut down the server. error: BaseException = request_outputs logger.error( - "Received raw Exception %s rather than RPCError from " + "Received Exception %s rather than RPCError from " "MPLLMEngine. This should never happen.", error) self._errored = True request_id = None From a72094789b73daf0c6c5df3d360b24453ac9452b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 15:09:04 +0000 Subject: [PATCH 052/116] Revert change --- examples/openai_completion_client.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 5804345603a4..58519f978d34 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -2,25 +2,26 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8001/v1" +openai_api_base = "http://localhost:8000/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") api_key=openai_api_key, base_url=openai_api_base, - max_retries=0) +) models = client.models.list() model = models.data[0].id # Completion API -stream = True +stream = False completion = client.completions.create( model=model, prompt="A robot may not injure a human being", + echo=False, + n=2, stream=stream, - max_tokens=1000, -) + logprobs=3) print("Completion results:") if stream: From 5b8cee66b7434718a5519c700232078987b4d3d1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 16:08:25 +0000 Subject: [PATCH 053/116] enable shutdown for tp>1 --- vllm/engine/llm_engine.py | 1 + vllm/entrypoints/openai/api_server.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 78ddcd1daaf6..a5c0fb992ab4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1612,6 +1612,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # torch.distributed ops which may otherwise timeout, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. + logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() return ctx.request_outputs diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 42630d6060b3..f8664e01c830 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -193,13 +193,16 @@ async def build_async_engine_client_from_engine_args( yield mp_engine_client # type: ignore[misc] finally: - # Ensure rpc server process was terminated - engine_process.terminate() + # Shutdown engine process + # NOTE: terminate() (which sends SIGTERM), does not work here + # when tp>1. 
TODO: discuss with @njhill how we can have cleaner + # shutdown with terminate() rather than kill() + engine_process.kill() # Close all open connections to the backend mp_engine_client.close() - # Wait for server process to join + # Wait for engine process to join engine_process.join() # Lazy import for prometheus multiprocessing. From 97a241deeca569ea6b41eae30b460970483e321f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 16:16:34 +0000 Subject: [PATCH 054/116] format --- vllm/entrypoints/openai/api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f8664e01c830..99bfdfaa4717 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -193,7 +193,7 @@ async def build_async_engine_client_from_engine_args( yield mp_engine_client # type: ignore[misc] finally: - # Shutdown engine process + # Shutdown engine process # NOTE: terminate() (which sends SIGTERM), does not work here # when tp>1. TODO: discuss with @njhill how we can have cleaner # shutdown with terminate() rather than kill() From 6d0570e516ffdebef7e832509bcc55cf6b965456 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 16:42:07 +0000 Subject: [PATCH 055/116] added error handling --- tests/mq_llm_engine/test_error_handling.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/mq_llm_engine/test_error_handling.py diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py new file mode 100644 index 000000000000..a346b76648f9 --- /dev/null +++ b/tests/mq_llm_engine/test_error_handling.py @@ -0,0 +1,20 @@ +import asyncio +import pytest +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) + +from vllm.engine.arg_utils import AsyncEngineArgs + +@pytest.mark.asyncio(scope="module") +async def test_bad_startup(): + bad_engine_args = AsyncEngineArgs(model="Qwen/Qwen2-0.5B-Instruct", + tensor_parallel_size=1234) + + async with asyncio.timeout(60.): + async with build_async_engine_client_from_engine_args( + bad_engine_args, False) as llm: + assert llm is None + + + + From eb267917e6123bf8e4e1d07ee96c42a666ca6da4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 8 Sep 2024 18:34:15 +0000 Subject: [PATCH 056/116] format --- tests/mq_llm_engine/test_error_handling.py | 20 -------------------- vllm/entrypoints/openai/api_server.py | 2 +- 2 files changed, 1 insertion(+), 21 deletions(-) delete mode 100644 tests/mq_llm_engine/test_error_handling.py diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py deleted file mode 100644 index a346b76648f9..000000000000 --- a/tests/mq_llm_engine/test_error_handling.py +++ /dev/null @@ -1,20 +0,0 @@ -import asyncio -import pytest -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) - -from vllm.engine.arg_utils import AsyncEngineArgs - -@pytest.mark.asyncio(scope="module") -async def test_bad_startup(): - bad_engine_args = AsyncEngineArgs(model="Qwen/Qwen2-0.5B-Instruct", - tensor_parallel_size=1234) - - async with asyncio.timeout(60.): - async with build_async_engine_client_from_engine_args( - bad_engine_args, False) as llm: - assert llm is None - - - - diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 99bfdfaa4717..783e2951b00c 100644 --- 
a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -131,7 +131,7 @@ async def build_async_engine_client_from_engine_args( # TODO: support embedding model via RPC. if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, engine_args.quantization) - or disable_frontend_multiprocessing): + or engine_args.engine_use_ray or disable_frontend_multiprocessing): engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) try: From e25605030a1e71f8c291cb3b499e9f1d41f7da66 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 9 Sep 2024 00:02:51 +0000 Subject: [PATCH 057/116] try out hwm --- vllm/engine/multiprocessing/client.py | 2 ++ vllm/engine/multiprocessing/engine.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 497376a24205..b7fd77b6a2b7 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -72,10 +72,12 @@ def __init__(self, ipc_path: str): # Send RPCGenerateRequest to the MQLLMEngine. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + # self.input_socket.set_hwm(0) self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") # Receive streams of RequestOutput from the MQLLMEngine. self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + # self.output_socket.set_hwm(0) self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") # IPC path for ack of check_health requests. diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index ee374c394d07..77923c32d2b0 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -73,11 +73,14 @@ def __init__(self, # Receive input from the client. self.input_socket = self.ctx.socket(zmq.constants.PULL) + # self.input_socket.set_hwm(0) self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") # Send output stream back to client. self.output_socket = self.ctx.socket(zmq.constants.PUSH) + # self.output_socket.set_hwm(0) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + # Send health status back to client. self.health_socket = self.ctx.socket(zmq.constants.PUSH) From 59c5aca54f5ea20f90dd64744a126724d17ea9ee Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 9 Sep 2024 15:03:36 -0700 Subject: [PATCH 058/116] Add stop_remote_worker_execution_loop for TP case --- vllm/engine/multiprocessing/engine.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 77923c32d2b0..6f75fd5c90ac 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -80,7 +80,6 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) # self.output_socket.set_hwm(0) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - # Send health status back to client. self.health_socket = self.ctx.socket(zmq.constants.PUSH) @@ -197,8 +196,13 @@ def run_engine_loop(self): """Core busy loop of the LLMEngine.""" while True: - # Poll until there is work to do. - if not self.engine.has_unfinished_requests(): + if not self.engine.has_unfinished_requests() and ( + self.input_socket.poll(timeout=0) == 0): + + # Stop remote worker loop in distributed case. + self.engine.stop_remote_worker_execution_loop() + + # Poll until there is work to do. 
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: logger.debug("Waiting for new requests in engine loop.") @@ -315,8 +319,9 @@ def _handle_utility_request(self, request: RPCUtilityRequest): def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" - output_bytes = pickle.dumps(outputs) - self.output_socket.send_multipart((output_bytes, ), copy=False) + if outputs: + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) def _send_healthy(self): """Send HEALTHY message to RPCClient.""" From 62f654aed2b02679ece7552c8a3b0d7db925da67 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 10 Sep 2024 07:03:06 -0700 Subject: [PATCH 059/116] Revert unnecessary stop_remote_worker_execution_loop --- vllm/engine/multiprocessing/engine.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 6f75fd5c90ac..df0020d98146 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -196,12 +196,7 @@ def run_engine_loop(self): """Core busy loop of the LLMEngine.""" while True: - if not self.engine.has_unfinished_requests() and ( - self.input_socket.poll(timeout=0) == 0): - - # Stop remote worker loop in distributed case. - self.engine.stop_remote_worker_execution_loop() - + if not self.engine.has_unfinished_requests(): # Poll until there is work to do. while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: logger.debug("Waiting for new requests in engine loop.") From 75c6157ed68dc3d23a4540bf64d0413f0e8e9ff3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 16:20:11 +0000 Subject: [PATCH 060/116] fixed magicmock errored --- tests/entrypoints/openai/test_serving_chat.py | 10 +++++++--- vllm/entrypoints/openai/serving_chat.py | 3 ++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index c3a6c65be1d9..5e16f71e7bfe 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from vllm.config import MultiModalConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.transformers_utils.tokenizer import get_tokenizer @@ -52,8 +52,9 @@ def test_async_serving_chat_init(): def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), @@ -74,7 +75,10 @@ def test_serving_chat_should_set_correct_max_tokens(): with suppress(Exception): asyncio.run(serving_chat.create_chat_completion(req)) - + + print(mock_engine) + print(mock_engine.generate) + print(mock_engine.generate.call_args) assert mock_engine.generate.call_args.args[1].max_tokens == 93 req.max_tokens = 10 diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 77b6a4c917e7..9ed5d445f4f0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py 
@@ -109,7 +109,7 @@ async def create_chat_completion( # success status before we actually start generating text :). if self.engine_client.errored: raise self.engine_client.dead_error - + try: ( lora_request, @@ -207,6 +207,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() + print("calling generate") result_generator = self.engine_client.generate( engine_inputs, sampling_params, From 370c104bad6bcff3b219a1a41011f72229073346 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 16:46:03 +0000 Subject: [PATCH 061/116] fall back to asyncllmengine if pp --- vllm/engine/multiprocessing/client.py | 15 +++++++++++++++ vllm/entrypoints/openai/api_server.py | 21 ++++----------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index b7fd77b6a2b7..30dfc9124359 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -12,6 +12,7 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, @@ -94,6 +95,20 @@ def __init__(self, ipc_path: str): # Loop to check health of the LLMEngine periodically. self.health_loop: Optional[asyncio.Task] = None + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + is_embedding = ModelConfig( + model=engine_args.model, + tokenizer=engine_args.model, + tokenizer_mode="auto", + trust_remote_code=engine_args.trust_remote_code, + quantization=engine_args.quantization, + seed=0, + dtype="auto").embedding_mode + is_pp = engine_args.pipeline_parallel_size > 1 + is_engine_use_ray = engine_args.engine_use_ray + return is_embedding or is_pp or is_engine_use_ray + @contextmanager def get_data_socket(self) -> Iterator[Socket]: socket = self.context.socket(zmq.constants.DEALER) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 783e2951b00c..7516cf359034 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -67,18 +67,6 @@ _running_tasks: Set[asyncio.Task] = set() - -def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str]) -> bool: - return ModelConfig(model=model_name, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - quantization=quantization, - seed=0, - dtype="auto").embedding_mode - - @asynccontextmanager async def lifespan(app: FastAPI): @@ -127,11 +115,10 @@ async def build_async_engine_client_from_engine_args( Returns the Client or None if the creation failed. """ - # If manually triggered or embedding model, use AsyncLLMEngine in process. - # TODO: support embedding model via RPC. - if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization) - or engine_args.engine_use_ray or disable_frontend_multiprocessing): + # Fall back + # TODO: fill out feature matrix. 
+ if (MQLLMEngineClient.is_unsupported_config(engine_args) + or disable_frontend_multiprocessing): engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) try: From 0cf9551677ea841c594cd49a23d510e44546fd7b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 16:52:29 +0000 Subject: [PATCH 062/116] formatting --- tests/entrypoints/openai/test_serving_chat.py | 2 +- vllm/entrypoints/openai/api_server.py | 4 ++-- vllm/entrypoints/openai/serving_chat.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 5e16f71e7bfe..f2a4e3e27cd3 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -75,7 +75,7 @@ def test_serving_chat_should_set_correct_max_tokens(): with suppress(Exception): asyncio.run(serving_chat.create_chat_completion(req)) - + print(mock_engine) print(mock_engine.generate) print(mock_engine.generate.call_args) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7516cf359034..2c0244709f5b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,7 +18,6 @@ from typing_extensions import assert_never import vllm.envs as envs -from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.multiprocessing.client import MQLLMEngineClient @@ -67,6 +66,7 @@ _running_tasks: Set[asyncio.Task] = set() + @asynccontextmanager async def lifespan(app: FastAPI): @@ -118,7 +118,7 @@ async def build_async_engine_client_from_engine_args( # Fall back # TODO: fill out feature matrix. if (MQLLMEngineClient.is_unsupported_config(engine_args) - or disable_frontend_multiprocessing): + or disable_frontend_multiprocessing): engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) try: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 01d2b7529596..0909381dd750 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -109,7 +109,7 @@ async def create_chat_completion( # success status before we actually start generating text :). 
if self.engine_client.errored: raise self.engine_client.dead_error - + try: ( lora_request, From 72f72fd16377e6a350b3ae97694abd3c69c73cbd Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 17:38:33 +0000 Subject: [PATCH 063/116] stash --- vllm/engine/multiprocessing/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index df0020d98146..9ed651e39bef 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -339,5 +339,4 @@ def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, engine = MQLLMEngine.from_engine_args(engine_args=engine_args, usage_context=usage_context, ipc_path=ipc_path) - engine.start() From 364ed7f2586a103881f9686632dcc3e96745fd1d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 18:48:06 +0000 Subject: [PATCH 064/116] remove DO_LOG_STATS RPC call --- vllm/engine/multiprocessing/__init__.py | 7 +++---- vllm/engine/multiprocessing/client.py | 11 ++++------ vllm/engine/multiprocessing/engine.py | 27 +++++++++++-------------- vllm/entrypoints/openai/api_server.py | 2 +- 4 files changed, 20 insertions(+), 27 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 9835abd63237..f028276e1931 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -42,9 +42,8 @@ class RPCAbortRequest: request_id: str -class RPCUtilityRequest(Enum): - DO_LOG_STATS = 1 - CHECK_HEALTH = 2 +class RPCHealthRequest: + pass class RPCStartupRequest(Enum): @@ -58,7 +57,7 @@ class RPCStartupRequest(Enum): CLIENT_IS_READY = 8 -RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCUtilityRequest, +RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 30dfc9124359..9ade3c3b8fba 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -18,7 +18,7 @@ IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCGenerateRequest, - RPCStartupRequest, RPCUtilityRequest) + RPCHealthRequest, RPCStartupRequest) from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger @@ -134,7 +134,7 @@ async def run_check_health_loop(self, timeout: int): if await self.health_socket.poll(timeout=timeout) == 0: # Wakeup every N seconds and do a health probe. await self._send_one_way_rpc_request( - RPCUtilityRequest.CHECK_HEALTH, self.input_socket) + RPCHealthRequest(), self.input_socket) # Wait for ack from the health socket. 
await self._await_ack(error_message="Health check failed.", @@ -394,11 +394,8 @@ async def abort(self, request_id: str): request=RPCAbortRequest(request_id), socket=self.input_socket) async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(MQClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - socket=self.input_socket) + """Ignore do_log_stats (handled on MQLLMEngine polling)""" + pass async def check_health(self): """ diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 9ed651e39bef..0baaa2043f08 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -13,7 +13,7 @@ IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCGenerateRequest, - RPCStartupRequest, RPCUtilityRequest) + RPCHealthRequest, RPCStartupRequest) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -199,6 +199,7 @@ def run_engine_loop(self): if not self.engine.has_unfinished_requests(): # Poll until there is work to do. while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self.engine.do_log_stats() logger.debug("Waiting for new requests in engine loop.") # Handle any input from the client. @@ -249,8 +250,8 @@ def handle_new_input(self): self._handle_generate_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) - elif isinstance(request, RPCUtilityRequest): - self._handle_utility_request(request) + elif isinstance(request, RPCHealthRequest): + self._handle_health_request() else: raise ValueError("Unknown RPCRequest Type: {request}") @@ -299,18 +300,14 @@ def _handle_abort_request(self, request: RPCAbortRequest): if self.log_requests: logger.info("Aborted request %s.", request.request_id) - def _handle_utility_request(self, request: RPCUtilityRequest): - if request == RPCUtilityRequest.DO_LOG_STATS: - if not self._errored: - self.engine.do_log_stats() - elif request == RPCUtilityRequest.CHECK_HEALTH: - if self._errored: - self._send_unhealthy(ENGINE_DEAD_ERROR) - try: - self.engine.check_health() - self._send_healthy() - except Exception as e: - self._send_unhealthy(e) + def _handle_health_request(self): + if self._errored: + self._send_unhealthy(ENGINE_DEAD_ERROR) + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._send_unhealthy(e) def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2c0244709f5b..9fa44a956299 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -72,7 +72,7 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(10) + await asyncio.sleep(10.) 
await engine_client.do_log_stats() if not engine_args.disable_log_stats: From f7fdf69d1a7b61a382c3ba9f4453791680bbf23c Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 19:04:52 +0000 Subject: [PATCH 065/116] cleanup health check --- vllm/engine/multiprocessing/engine.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 0baaa2043f08..3cb61dd2d423 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -303,11 +303,10 @@ def _handle_abort_request(self, request: RPCAbortRequest): def _handle_health_request(self): if self._errored: self._send_unhealthy(ENGINE_DEAD_ERROR) - try: - self.engine.check_health() - self._send_healthy() - except Exception as e: - self._send_unhealthy(e) + + # Raises error if unhealthy. + self.engine.check_health() + self._send_healthy() def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" From 7e61cdb8b86d557a708fcc854acc7ed9efa6217c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 10 Sep 2024 12:17:56 -0700 Subject: [PATCH 066/116] Use pickle for requests too --- vllm/engine/multiprocessing/client.py | 54 +++++++++++++++++---------- vllm/engine/multiprocessing/engine.py | 8 +++- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 9ade3c3b8fba..e546a365cea9 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -1,4 +1,5 @@ import asyncio +import copy import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, @@ -252,7 +253,7 @@ async def _send_get_data_rpc_request(self, request: RPCStartupRequest, """Send an RPC request that is expecting data back.""" # Ping RPCServer with a request. - await socket.send_multipart((cloudpickle.dumps(request), ), copy=False) + await socket.send_multipart((pickle.dumps(request), ), copy=False) # Make sure the server responds in time. if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: @@ -283,7 +284,7 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_T, socket: Socket): """Send one-way RPC request to trigger an action.""" - await socket.send_multipart((cloudpickle.dumps(request), )) + await socket.send_multipart((pickle.dumps(request), )) async def _await_ack(self, error_message: str, socket: Socket): "Await acknowledgement that a request succeeded." @@ -439,37 +440,52 @@ async def generate( self.output_queues[request_id] = queue try: - # 2) Send the RPCGenerateRequest to the MQLLMEngine. 
- await self.input_socket.send_multipart((cloudpickle.dumps( + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if sampling_params.logits_processors: + # Defensive shallow copy + sampling_params = copy.copy(sampling_params) + logits_processors = sampling_params.logits_processors + sampling_params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( RPCGenerateRequest( inputs=inputs, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) + prompt_adapter_request=prompt_adapter_request)) - # 3) Stream the RequestOutputs from the output queue. Note + # 3) Send the RPCGenerateRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note # that the output_loop pushes RequestOutput objects to this # queue after pulling them from the zmq socket. finished = False - while not finished: - request_output = await queue.get() - - if isinstance(request_output, BaseException): - raise request_output - - finished = request_output.finished - yield request_output - + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self._errored: + await self.abort(request_id) finally: # TODO: check if excepted requests are getting here. self.output_queues.pop(request_id) - # Request was canceled by the client. 
- if not finished and not self._errored: - await self.abort(request_id) - async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: raise NotImplementedError( diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 3cb61dd2d423..58965976aa8e 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -243,10 +243,14 @@ def handle_new_input(self): """Handle new input from the socket""" try: while self.input_socket.poll(timeout=0) != 0: - message = self.input_socket.recv(copy=False) - request = cloudpickle.loads(message.buffer) + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) if isinstance(request, RPCGenerateRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + lprocs = cloudpickle.loads(frames[1].buffer) + request.sampling_params.logits_processors = lprocs self._handle_generate_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) From 3e84c8c87725c542491d5cb9550bea5f5b6dc273 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 19:47:13 +0000 Subject: [PATCH 067/116] Remove hwm --- vllm/engine/multiprocessing/client.py | 3 --- vllm/engine/multiprocessing/engine.py | 12 +++++------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index e546a365cea9..8345fcc6d512 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -74,12 +74,10 @@ def __init__(self, ipc_path: str): # Send RPCGenerateRequest to the MQLLMEngine. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) - # self.input_socket.set_hwm(0) self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") # Receive streams of RequestOutput from the MQLLMEngine. self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - # self.output_socket.set_hwm(0) self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") # IPC path for ack of check_health requests. @@ -483,7 +481,6 @@ async def generate( if not finished and not self._errored: await self.abort(request_id) finally: - # TODO: check if excepted requests are getting here. self.output_queues.pop(request_id) async def encode(self, *args, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 58965976aa8e..9c14a3fd05dc 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -31,15 +31,15 @@ class MQLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. This class is used to wrap the :class:`LLMEngine` class to enable use - in asynchronous manner. It runs a background loop and uses zeromq to - receive new requests and stream outputs incrementally to another process. + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. - The :class:`LLMEngine` is kicked off when a new RPCGenerateRequest - is received by the input_socket. + The :class:`LLMEngine.generate` is kicked off when a new + RPCGenerateRequest is received by the input_socket. The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal - :class:`LLMEngine.step()` and sends the RequestOutputs back over + :class:`LLMEngine.step()`, and sends the RequestOutputs back over the output_socket. 
If use_async_sockets is set, the logic associated with reading new @@ -73,12 +73,10 @@ def __init__(self, # Receive input from the client. self.input_socket = self.ctx.socket(zmq.constants.PULL) - # self.input_socket.set_hwm(0) self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") # Send output stream back to client. self.output_socket = self.ctx.socket(zmq.constants.PUSH) - # self.output_socket.set_hwm(0) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") # Send health status back to client. From 2559813d6570c22f0db14dbf7f7cec7a03a7816a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 10 Sep 2024 15:09:12 -0700 Subject: [PATCH 068/116] Simplify configs setup --- vllm/engine/multiprocessing/client.py | 76 +++++---------------------- vllm/engine/multiprocessing/engine.py | 23 +++----- vllm/entrypoints/openai/api_server.py | 11 ++-- 3 files changed, 26 insertions(+), 84 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 8345fcc6d512..43c213cfc1bd 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -11,8 +11,7 @@ from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import DecodingConfig, EngineConfig, LoRAConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, @@ -67,11 +66,23 @@ class MQLLMEngineClient: every N seconds, confirming the engine is healthy """ - def __init__(self, ipc_path: str): + def __init__(self, ipc_path: str, engine_config: EngineConfig): self.context = zmq.asyncio.Context() self._errored = False self.dead_error = ENGINE_DEAD_ERROR + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + ) + # Send RPCGenerateRequest to the MQLLMEngine. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") @@ -207,21 +218,8 @@ async def setup(self): # Wait until server is ready. await self._wait_for_server_rpc(socket) - # Get the configs. - self.model_config = await self._get_model_config_rpc(socket) - self.decoding_config = await self._get_decoding_config_rpc(socket) self.tracing_flag = await self._is_tracing_enabled_rpc(socket) - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await - self._get_scheduler_config_rpc(socket)), - parallel_config=(await self._get_parallel_config_rpc(socket)), - enable_lora=bool(await self._get_lora_config_rpc(socket)), - ) - # Start health_loop. 
self.health_loop = asyncio.create_task( self.run_check_health_loop( @@ -330,52 +328,6 @@ async def _notify_ready(self, socket: Socket): await self._send_one_way_rpc_request( request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) - async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCStartupRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server", - socket=socket) - - async def _get_decoding_config_rpc(self, socket: Socket) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCStartupRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server", - socket=socket) - - async def _get_parallel_config_rpc(self, socket: Socket) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCStartupRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server", - socket=socket) - - async def _get_scheduler_config_rpc(self, - socket: Socket) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCStartupRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server", - socket=socket) - - async def _get_lora_config_rpc(self, socket: Socket) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCStartupRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server", - socket=socket) - async def _is_tracing_enabled_rpc(self, socket: Socket) -> bool: """Get is_tracing_enabled flag from the RPCServer""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 9c14a3fd05dc..8da8629acbba 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -164,17 +164,7 @@ def run_startup_loop(self) -> None: request: RPCStartupRequest = pickle.loads(message.buffer) # Handle the query from the Client. - if request == RPCStartupRequest.GET_MODEL_CONFIG: - response = self.engine.get_model_config() - elif request == RPCStartupRequest.GET_DECODING_CONFIG: - response = self.engine.get_decoding_config() - elif request == RPCStartupRequest.GET_LORA_CONFIG: - response = self.engine.get_lora_config() - elif request == RPCStartupRequest.GET_SCHEDULER_CONFIG: - response = self.engine.get_scheduler_config() - elif request == RPCStartupRequest.GET_PARALLEL_CONFIG: - response = self.engine.get_parallel_config() - elif request == RPCStartupRequest.GET_TRACING_ENABLED: + if request == RPCStartupRequest.GET_TRACING_ENABLED: response = self.engine.is_tracing_enabled() elif request == RPCStartupRequest.IS_SERVER_READY: response = VLLM_RPC_SUCCESS_STR @@ -183,12 +173,11 @@ def run_startup_loop(self) -> None: # Breakout of loop once client is ready. 
client_is_ready = True - socket.send_multipart((identity, pickle.dumps(response)), - copy=False) - except Exception as e: - socket.send_multipart((identity, pickle.dumps(e)), - copy=False) + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) def run_engine_loop(self): """Core busy loop of the LLMEngine.""" @@ -305,7 +294,7 @@ def _handle_abort_request(self, request: RPCAbortRequest): def _handle_health_request(self): if self._errored: self._send_unhealthy(ENGINE_DEAD_ERROR) - + # Raises error if unhealthy. self.engine.check_health() self._send_healthy() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9fa44a956299..7d071cc61240 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -149,11 +149,6 @@ async def build_async_engine_client_from_engine_args( logger.info("Multiprocessing frontend to use %s for IPC Path.", ipc_path) - # Build RPCClient, which conforms to EngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) - mp_engine_client = MQLLMEngineClient(ipc_path) - # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process @@ -166,6 +161,12 @@ async def build_async_engine_client_from_engine_args( engine_process.start() logger.info("Started engine process with PID %d", engine_process.pid) + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) + try: while True: try: From d0a0f8be85fdde96d0437ef4aa48740e46bbf025 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 22:52:41 +0000 Subject: [PATCH 069/116] stash --- vllm/engine/multiprocessing/client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 8345fcc6d512..093d7831e363 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -92,6 +92,7 @@ def __init__(self, ipc_path: str): self.output_loop = asyncio.create_task(self.run_output_handler_loop()) # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. self.health_loop: Optional[asyncio.Task] = None @staticmethod @@ -171,9 +172,10 @@ async def run_output_handler_loop(self): self._errored = True exception = rpc_error.exception else: - # MPLLMEngine should always return an RPCError - # when an issue arises. If we are here, we are in a - # bad state and should shut down the server. + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. 
error: BaseException = request_outputs logger.error( "Received Exception %s rather than RPCError from " From 021fed35c7c4afbc7b8f193785db26ab6b9ede1c Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 10 Sep 2024 23:41:55 +0000 Subject: [PATCH 070/116] added tests --- tests/mq_llm_engine/__init__.py | 0 tests/mq_llm_engine/test_errors.py | 89 ++++++++++++++++++++++++++++++ tests/mq_llm_engine/utils.py | 34 ++++++++++++ 3 files changed, 123 insertions(+) create mode 100644 tests/mq_llm_engine/__init__.py create mode 100644 tests/mq_llm_engine/test_errors.py create mode 100644 tests/mq_llm_engine/utils.py diff --git a/tests/mq_llm_engine/__init__.py b/tests/mq_llm_engine/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/mq_llm_engine/test_errors.py b/tests/mq_llm_engine/test_errors.py new file mode 100644 index 000000000000..939f226e5c5c --- /dev/null +++ b/tests/mq_llm_engine/test_errors.py @@ -0,0 +1,89 @@ +import asyncio +import pytest +import tempfile +import uuid + +from unittest.mock import Mock + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing import ENGINE_DEAD_ERROR + +from vllm.engine.multiprocessing.engine import MQLLMEngine + +from vllm.usage.usage_lib import UsageContext + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine + + +MODEL = "Qwen/Qwen2-0.5B-Instruct" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError("foo") + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + +def run_with_evil_forward(engine_args: AsyncEngineArgs, + ipc_path: str): + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context= UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + # Raise error during first forward pass. + engine.engine.model_executor.execute_model = Mock( + side_effect=RAISED_ERROR) + engine.start() + + +@pytest.mark.asyncio +async def test_health_loop(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_forward) as engine: + + # Make client. + client = await engine.make_client() + + POLL_TIME = 1.0 + health_task = asyncio.create_task( + client.run_check_health_loop(timeout=POLL_TIME)) + + # Server should be healthy. + await asyncio.sleep(POLL_TIME * 3) + await client.check_health() + + # Throws an error in engine.step(). + try: + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + except Exception as e: + # First exception should be a RAISED_ERROR + assert repr(e) == repr(RAISED_ERROR) + + # Engine is errored, should get ENGINE_DEAD_ERROR. 
+ try: + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + except Exception as e: + # First exception should be a RAISED_ERROR + assert e == ENGINE_DEAD_ERROR, ( + "Engine should be dead and raise ENGINE_DEAD_ERROR") + + + asyncio.sleep(POLL_TIME * 3) + try: + await client.check_health() + except Exception as e: + assert e == ENGINE_DEAD_ERROR, ( + "Engine should be dead and raise ENGINE_DEAD_ERROR") + + await health_task + client.close() + + diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py new file mode 100644 index 000000000000..539df1ea085f --- /dev/null +++ b/tests/mq_llm_engine/utils.py @@ -0,0 +1,34 @@ +import multiprocessing +from typing import Callable +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.client import MQLLMEngineClient + +class RemoteMQLLMEngine: + def __init__(self, + run_fn: Callable, + engine_args: AsyncEngineArgs, + ipc_path: str) -> None: + + self.engine_args = engine_args + self.ipc_path = ipc_path + context = multiprocessing.get_context("spawn") + self.proc = context.Process(target=run_fn, + args=(engine_args, ipc_path)) + self.proc.start() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.kill() + + async def make_client(self) -> MQLLMEngineClient: + engine_config = self.engine_args.create_engine_config() + client = MQLLMEngineClient(self.ipc_path, engine_config) + while True: + try: + await client.setup() + break + except TimeoutError: + assert self.proc.is_alive() + return client From fd6ee43b4b592892a3ae8a383866222d0205f44a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 00:13:54 +0000 Subject: [PATCH 071/116] added failed health check --- tests/mq_llm_engine/test_errors.py | 89 +++++++++++++++++++++------ vllm/engine/multiprocessing/client.py | 38 +++++++----- vllm/engine/multiprocessing/engine.py | 1 + 3 files changed, 92 insertions(+), 36 deletions(-) diff --git a/tests/mq_llm_engine/test_errors.py b/tests/mq_llm_engine/test_errors.py index 939f226e5c5c..5011d98557bd 100644 --- a/tests/mq_llm_engine/test_errors.py +++ b/tests/mq_llm_engine/test_errors.py @@ -26,35 +26,35 @@ def tmp_socket(): yield f"ipc://{td}/{uuid.uuid4()}" def run_with_evil_forward(engine_args: AsyncEngineArgs, - ipc_path: str): - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context= UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) + ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context= UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) # Raise error during first forward pass. engine.engine.model_executor.execute_model = Mock( side_effect=RAISED_ERROR) - engine.start() + # Run engine. + engine.start() @pytest.mark.asyncio -async def test_health_loop(tmp_socket): +async def test_evil_forward(tmp_socket): with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, ipc_path=tmp_socket, run_fn=run_with_evil_forward) as engine: - # Make client. client = await engine.make_client() - POLL_TIME = 1.0 - health_task = asyncio.create_task( - client.run_check_health_loop(timeout=POLL_TIME)) + # Fast health probe. + fast_health_probe_task = asyncio.create_task( + client.run_check_health_loop(timeout=1.0)) - # Server should be healthy. - await asyncio.sleep(POLL_TIME * 3) + # Server should be healthy after initial probe. 
+ await asyncio.sleep(2.0) await client.check_health() - # Throws an error in engine.step(). + # Throws an error in first forward pass. try: async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), @@ -63,6 +63,7 @@ async def test_health_loop(tmp_socket): except Exception as e: # First exception should be a RAISED_ERROR assert repr(e) == repr(RAISED_ERROR) + assert client.errored # Engine is errored, should get ENGINE_DEAD_ERROR. try: @@ -71,19 +72,67 @@ async def test_health_loop(tmp_socket): request_id=uuid.uuid4()): pass except Exception as e: - # First exception should be a RAISED_ERROR + # Next excpetion should be an ENGINE_DEAD_ERROR assert e == ENGINE_DEAD_ERROR, ( "Engine should be dead and raise ENGINE_DEAD_ERROR") - + assert client.errored - asyncio.sleep(POLL_TIME * 3) + await asyncio.sleep(2.0) try: await client.check_health() except Exception as e: - assert e == ENGINE_DEAD_ERROR, ( - "Engine should be dead and raise ENGINE_DEAD_ERROR") + assert repr(e) == repr(RAISED_ERROR), ( + "Health check raise the original error.") - await health_task + # Cleanup + await fast_health_probe_task client.close() +def run_with_evil_model_executor_health( + engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context= UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + # Raise error during first forward pass. + engine.engine.model_executor.check_health = Mock( + side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + +@pytest.mark.asyncio +async def test_failed_health_check(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_model_executor_health) as engine: + + client = await engine.make_client() + assert client.is_running + + # Health probe should throw RAISED_ERROR. + await asyncio.sleep(10) + try: + await client.check_health() + except Exception as e: + # First exception should be a RAISED_ERROR + assert repr(e) == repr(RAISED_ERROR), ( + "Health check raise the original error.") + assert client.errored + + # Generate call should throw ENGINE_DEAD_ERROR + try: + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + except Exception as e: + # Next excpetion should be an ENGINE_DEAD_ERROR + assert e == ENGINE_DEAD_ERROR, ( + "Engine should be dead and raise ENGINE_DEAD_ERROR") + assert client.errored + + # Cleanup + client.close() + \ No newline at end of file diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index b14dab1d5893..eac4e581e005 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -68,7 +68,7 @@ class MQLLMEngineClient: def __init__(self, ipc_path: str, engine_config: EngineConfig): self.context = zmq.asyncio.Context() - self._errored = False + self._errored_with: Optional[BaseException] = None self.dead_error = ENGINE_DEAD_ERROR # Get the configs. 
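# ---------------------------------------------------------------------------
# Illustrative aside, not part of the patch: the `_errored_with` bookkeeping
# introduced in the hunk above and completed in the hunks below -- store the
# *first* exception instead of a bare flag so health checks can re-raise the
# root cause.  `TinyClient` is a hypothetical stand-in for MQLLMEngineClient.
import asyncio
from typing import Optional


class TinyClient:

    def __init__(self) -> None:
        self._errored_with: Optional[BaseException] = None

    def _set_errored(self, e: BaseException) -> None:
        if self._errored_with is None:  # record only the first failure
            self._errored_with = e

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    async def check_health(self) -> None:
        if self._errored_with is not None:
            raise self._errored_with


client = TinyClient()
client._set_errored(KeyError("foo"))       # the root cause
client._set_errored(RuntimeError("dead"))  # ignored: not the first error
assert client.errored
try:
    asyncio.run(client.check_health())
except KeyError as e:
    print("check_health surfaced the original error:", e)
# ---------------------------------------------------------------------------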
@@ -162,8 +162,7 @@ async def run_check_health_loop(self, timeout: int): logger.debug("Shutting down MQLLMEngineClient check health loop.") except Exception as e: - logger.exception(repr(e)) - self._errored = True + self.raise_exception(e) async def run_output_handler_loop(self): """Get RequestOutputs from Engine and stream to request Queues""" @@ -179,8 +178,6 @@ async def run_output_handler_loop(self): if isinstance(request_outputs, RPCError): rpc_error: RPCError = request_outputs request_id = rpc_error.request_id - if rpc_error.is_engine_errored: - self._errored = True exception = rpc_error.exception else: # MPLLMEngine should always return an RPCError to @@ -191,10 +188,13 @@ async def run_output_handler_loop(self): logger.error( "Received Exception %s rather than RPCError from " "MPLLMEngine. This should never happen.", error) - self._errored = True request_id = None exception = error + # If this is the first error, set _errored_with + if not self._errored_with: + self._errored_with = exception + if request_id is None: for queue_i in tuple(self.output_queues.values()): queue_i.put_nowait(exception) @@ -244,6 +244,12 @@ def close(self): self.health_loop.cancel() self.output_loop.cancel() + def raise_exception(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + async def _send_get_data_rpc_request(self, request: RPCStartupRequest, expected_type: Any, error_message: str, @@ -353,24 +359,23 @@ async def do_log_stats(self): async def check_health(self): """ The check health loop probes the health status of the - Engine's health every N seconds and sets _errored if - the engine is unhealth. So check_health just raises - an ENGINE_DEAD_ERROR if we find self._errored + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. """ - if self._errored: - raise ENGINE_DEAD_ERROR + if self._errored_with is not None: + raise self._errored_with @property def is_running(self) -> bool: - return not self._errored + return not self.errored @property def is_stopped(self) -> bool: - return self._errored + return self.errored @property def errored(self) -> bool: - return self._errored + return self._errored_with is not None async def generate( self, @@ -383,7 +388,8 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - if self._errored: + # If already dead, error out. + if self.errored: raise ENGINE_DEAD_ERROR # 1) Create output queue for this requests. @@ -432,7 +438,7 @@ async def generate( yield request_output finally: # Request was canceled by the client. 
- if not finished and not self._errored: + if not finished and not self.errored: await self.abort(request_id) finally: self.output_queues.pop(request_id) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 8da8629acbba..6ba695c73f32 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -131,6 +131,7 @@ def start(self): except KeyboardInterrupt: logger.debug("Shutting down MQLLMEngine.") finally: + logger.debug("MQLLMEngine is shut down.") self.cleanup() def cleanup(self): From ccb43a392c9452990631289b3ef517b5f92fc550 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 00:18:24 +0000 Subject: [PATCH 072/116] rename --- ...{test_errors.py => test_error_handling.py} | 65 +++++++++++-------- tests/mq_llm_engine/utils.py | 9 +-- vllm/engine/multiprocessing/client.py | 5 +- 3 files changed, 45 insertions(+), 34 deletions(-) rename tests/mq_llm_engine/{test_errors.py => test_error_handling.py} (73%) diff --git a/tests/mq_llm_engine/test_errors.py b/tests/mq_llm_engine/test_error_handling.py similarity index 73% rename from tests/mq_llm_engine/test_errors.py rename to tests/mq_llm_engine/test_error_handling.py index 5011d98557bd..41b3c59338ae 100644 --- a/tests/mq_llm_engine/test_errors.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -1,43 +1,41 @@ import asyncio -import pytest import tempfile import uuid - from unittest.mock import Mock +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing import ENGINE_DEAD_ERROR - from vllm.engine.multiprocessing.engine import MQLLMEngine - from vllm.usage.usage_lib import UsageContext -from tests.mq_llm_engine.utils import RemoteMQLLMEngine - - MODEL = "Qwen/Qwen2-0.5B-Instruct" ENGINE_ARGS = AsyncEngineArgs(model=MODEL) RAISED_ERROR = KeyError("foo") + @pytest.fixture(scope="function") def tmp_socket(): with tempfile.TemporaryDirectory() as td: yield f"ipc://{td}/{uuid.uuid4()}" -def run_with_evil_forward(engine_args: AsyncEngineArgs, - ipc_path: str): + +def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): # Make engine. - engine = MQLLMEngine.from_engine_args(engine_args=engine_args, - usage_context= UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) # Raise error during first forward pass. - engine.engine.model_executor.execute_model = Mock( - side_effect=RAISED_ERROR) + engine.engine.model_executor.execute_model = Mock(side_effect=RAISED_ERROR) # Run engine. engine.start() + @pytest.mark.asyncio async def test_evil_forward(tmp_socket): with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, @@ -72,7 +70,7 @@ async def test_evil_forward(tmp_socket): request_id=uuid.uuid4()): pass except Exception as e: - # Next excpetion should be an ENGINE_DEAD_ERROR + # Next exception should be an ENGINE_DEAD_ERROR assert e == ENGINE_DEAD_ERROR, ( "Engine should be dead and raise ENGINE_DEAD_ERROR") assert client.errored @@ -89,24 +87,26 @@ async def test_evil_forward(tmp_socket): client.close() -def run_with_evil_model_executor_health( - engine_args: AsyncEngineArgs, ipc_path: str): +def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, + ipc_path: str): # Make engine. 
- engine = MQLLMEngine.from_engine_args(engine_args=engine_args, - usage_context= UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) # Raise error during first forward pass. - engine.engine.model_executor.check_health = Mock( - side_effect=RAISED_ERROR) + engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) # Run engine. engine.start() + @pytest.mark.asyncio async def test_failed_health_check(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_model_executor_health) as engine: + with RemoteMQLLMEngine( + engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_model_executor_health) as engine: client = await engine.make_client() assert client.is_running @@ -128,11 +128,22 @@ async def test_failed_health_check(tmp_socket): request_id=uuid.uuid4()): pass except Exception as e: - # Next excpetion should be an ENGINE_DEAD_ERROR + # Next exception should be an ENGINE_DEAD_ERROR assert e == ENGINE_DEAD_ERROR, ( "Engine should be dead and raise ENGINE_DEAD_ERROR") assert client.errored # Cleanup client.close() - \ No newline at end of file + + +def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + # Raise error during abort call. + engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) + # Run engine. + engine.start() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 539df1ea085f..0ef649781c46 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -1,18 +1,19 @@ import multiprocessing from typing import Callable + from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing.client import MQLLMEngineClient + class RemoteMQLLMEngine: - def __init__(self, - run_fn: Callable, - engine_args: AsyncEngineArgs, + + def __init__(self, run_fn: Callable, engine_args: AsyncEngineArgs, ipc_path: str) -> None: self.engine_args = engine_args self.ipc_path = ipc_path context = multiprocessing.get_context("spawn") - self.proc = context.Process(target=run_fn, + self.proc = context.Process(target=run_fn, args=(engine_args, ipc_path)) self.proc.start() diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index eac4e581e005..5495c55f5e86 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -181,8 +181,8 @@ async def run_output_handler_loop(self): exception = rpc_error.exception else: # MPLLMEngine should always return an RPCError to - # the output_socket when an issue arises. - # If we are here, we are in a bad state and + # the output_socket when an issue arises. + # If we are here, we are in a bad state and # should shut down the server. 
error: BaseException = request_outputs logger.error( @@ -249,7 +249,6 @@ def raise_exception(self, e: BaseException): if self._errored_with is None: self._errored_with = e - async def _send_get_data_rpc_request(self, request: RPCStartupRequest, expected_type: Any, error_message: str, From 1aa0823c36bdad92795eab24fbcad297745510a5 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 00:44:06 +0000 Subject: [PATCH 073/116] added failed abort test --- tests/mq_llm_engine/test_error_handling.py | 48 ++++++++++++++++++++++ vllm/engine/multiprocessing/engine.py | 6 +++ 2 files changed, 54 insertions(+) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 41b3c59338ae..b57fca5e12a1 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -147,3 +147,51 @@ def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) # Run engine. engine.start() + +@pytest.mark.asyncio +async def test_failed_abort(tmp_socket): + with RemoteMQLLMEngine( + engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: + + client = await engine.make_client() + assert client.is_running + + # Firsh check health should work. + await asyncio.sleep(10) + await client.check_health() + + # Trigger an abort on the client side. + async def bad_abort_after_2s(): + await asyncio.sleep(2.0) + await client.abort(request_id="foo") + + # Immediately should tigger error. + try: + await client.check_health() + except Exception as e: + # First exception should be a RAISED_ERROR + assert repr(e) == repr(RAISED_ERROR), ( + "Health check raise the original error.") + assert client.errored + + # Trigger an abort in 2s from now. + abort_task = asyncio.create_task(bad_abort_after_2s()) + + # Exception in abort() will happen during this generation. + try: + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams( + max_tokens=2000), + request_id=uuid.uuid4()): + pass + except Exception as e: + # Next exception should be an ENGINE_DEAD_ERROR + assert e == ENGINE_DEAD_ERROR, ( + "Engine should be dead and raise ENGINE_DEAD_ERROR") + assert client.errored + + await abort_task + + client.close() diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 6ba695c73f32..e67ef4e0d724 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -204,6 +204,12 @@ def run_engine_dead_loop(self): """Loop for replying to all requests that we are dead.""" if not self._errored: raise ValueError("In dead loop, but found _errored=False") + + # Send a single message that we are dead. 
+ rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR) + self._send_outputs(rpc_err) while True: # Poll until there is a request From fe22fe22193b94425cb46964c123421629334b34 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 00:48:16 +0000 Subject: [PATCH 074/116] formatting --- .buildkite/test-pipeline.yaml | 4 +- .../entrypoints/openai/test_mp_api_server.py | 40 ------------ tests/mq_llm_engine/test_error_handling.py | 61 +++++++++++++++---- vllm/engine/multiprocessing/engine.py | 2 +- 4 files changed, 53 insertions(+), 54 deletions(-) delete mode 100644 tests/entrypoints/openai/test_mp_api_server.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a0c7b7442b3b..7a9b2da34081 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -43,13 +43,15 @@ steps: fast_check: true source_file_dependencies: - vllm/ + - tests/mq_llm_engine - tests/async_engine - tests/test_inputs - tests/multimodal - tests/test_utils - tests/worker commands: - - pytest -v -s async_engine # Async Engine + - pytest -v -s mq_llm_engine # MQLLMEngine + - pytest -v -s async_engine # AsyncLLMEngine - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils diff --git a/tests/entrypoints/openai/test_mp_api_server.py b/tests/entrypoints/openai/test_mp_api_server.py deleted file mode 100644 index fbfe0db19dd0..000000000000 --- a/tests/entrypoints/openai/test_mp_api_server.py +++ /dev/null @@ -1,40 +0,0 @@ -import time - -import pytest - -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import FlexibleArgumentParser - - -@pytest.mark.asyncio -async def test_mp_crash_detection(): - - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - # use an invalid tensor_parallel_size to trigger the - # error in the server - args.tensor_parallel_size = 65536 - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index b57fca5e12a1..0070158076b0 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -1,5 +1,6 @@ import asyncio import tempfile +import time import uuid from unittest.mock import Mock @@ -10,7 +11,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing import ENGINE_DEAD_ERROR from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.entrypoints.openai.api_server import build_async_engine_client +from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser MODEL = "Qwen/Qwen2-0.5B-Instruct" ENGINE_ARGS = AsyncEngineArgs(model=MODEL) @@ -148,12 +152,12 @@ 
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): # Run engine. engine.start() + @pytest.mark.asyncio async def test_failed_abort(tmp_socket): - with RemoteMQLLMEngine( - engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_abort) as engine: + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: client = await engine.make_client() assert client.is_running @@ -161,13 +165,13 @@ async def test_failed_abort(tmp_socket): # Firsh check health should work. await asyncio.sleep(10) await client.check_health() - + # Trigger an abort on the client side. async def bad_abort_after_2s(): - await asyncio.sleep(2.0) + await asyncio.sleep(2.0) await client.abort(request_id="foo") - # Immediately should tigger error. + # Immediately should trigger error. try: await client.check_health() except Exception as e: @@ -181,10 +185,10 @@ async def bad_abort_after_2s(): # Exception in abort() will happen during this generation. try: - async for _ in client.generate(inputs="Hello my name is", - sampling_params=SamplingParams( - max_tokens=2000), - request_id=uuid.uuid4()): + async for _ in client.generate( + inputs="Hello my name is", + sampling_params=SamplingParams(max_tokens=2000), + request_id=uuid.uuid4()): pass except Exception as e: # Next exception should be an ENGINE_DEAD_ERROR @@ -193,5 +197,38 @@ async def bad_abort_after_2s(): assert client.errored await abort_task - + client.close() + + +@pytest.mark.asyncio +async def test_mp_crash_detection(): + + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + # use an invalid tensor_parallel_size to trigger the + # error in the server + args.tensor_parallel_size = 65536 + + start = time.perf_counter() + async with build_async_engine_client(args): + pass + end = time.perf_counter() + + assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " + "if there is an error in the startup.") + + +@pytest.mark.asyncio +async def test_mp_cuda_init(): + # it should not crash, when cuda is initialized + # in the API server process + import torch + torch.cuda.init() + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + async with build_async_engine_client(args): + pass diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index e67ef4e0d724..442dae7332f2 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -204,7 +204,7 @@ def run_engine_dead_loop(self): """Loop for replying to all requests that we are dead.""" if not self._errored: raise ValueError("In dead loop, but found _errored=False") - + # Send a single message that we are dead. 
rpc_err = RPCError(request_id=None, is_engine_errored=True, From 3ce87021eaf37cafd80873ec1ad23a69fb5fe108 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 10 Sep 2024 19:13:23 -0700 Subject: [PATCH 075/116] Some more startup RPC simplification --- vllm/engine/multiprocessing/__init__.py | 13 ++++----- vllm/engine/multiprocessing/client.py | 39 +++++++++++-------------- vllm/engine/multiprocessing/engine.py | 14 +++++---- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index f028276e1931..09aa42498183 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -48,13 +48,12 @@ class RPCHealthRequest: class RPCStartupRequest(Enum): IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - GET_TRACING_ENABLED = 7 - CLIENT_IS_READY = 8 + CLIENT_IS_READY = 2 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 5495c55f5e86..5e2a2a9ccc5b 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -18,7 +18,8 @@ IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCGenerateRequest, - RPCHealthRequest, RPCStartupRequest) + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger @@ -218,9 +219,9 @@ async def setup(self): with self.get_data_socket() as socket: # Wait until server is ready. - await self._wait_for_server_rpc(socket) + response = await self._wait_for_server_rpc(socket) - self.tracing_flag = await self._is_tracing_enabled_rpc(socket) + self.tracing_flag = response.tracing_enabled # Start health_loop. self.health_loop = asyncio.create_task( @@ -249,7 +250,8 @@ def raise_exception(self, e: BaseException): if self._errored_with is None: self._errored_with = e - async def _send_get_data_rpc_request(self, request: RPCStartupRequest, + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, expected_type: Any, error_message: str, socket: Socket) -> Any: @@ -283,14 +285,15 @@ async def _send_get_data_rpc_request(self, request: RPCStartupRequest, return data - async def _send_one_way_rpc_request(self, request: RPC_REQUEST_T, + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, socket: Socket): """Send one-way RPC request to trigger an action.""" await socket.send_multipart((pickle.dumps(request), )) async def _await_ack(self, error_message: str, socket: Socket): - "Await acknowledgement that a request succeeded." 
+ """Await acknowledgement that a request succeeded.""" if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("MQLLMEngine didn't reply within " @@ -298,7 +301,8 @@ async def _await_ack(self, error_message: str, socket: Socket): await self._check_success(error_message, socket) - async def _check_success(self, error_message: str, socket: Socket): + @staticmethod + async def _check_success(error_message: str, socket: Socket): frame = await socket.recv(copy=False) response = pickle.loads(frame.buffer) @@ -320,14 +324,14 @@ async def get_model_config(self) -> ModelConfig: async def is_tracing_enabled(self) -> bool: return self.tracing_flag - async def _wait_for_server_rpc(self, socket: Socket): + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: """Wait for the RPCServer to start up.""" - await self._send_one_way_rpc_request( - request=RPCStartupRequest.IS_SERVER_READY, socket=socket) - - await self._await_ack(error_message="Unable to start RPC Server", - socket=socket) + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) async def _notify_ready(self, socket: Socket): """Get the RPCServer that the RPCClient is ready""" @@ -335,15 +339,6 @@ async def _notify_ready(self, socket: Socket): await self._send_one_way_rpc_request( request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) - async def _is_tracing_enabled_rpc(self, socket: Socket) -> bool: - """Get is_tracing_enabled flag from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCStartupRequest.GET_TRACING_ENABLED, - expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server", - socket=socket) - async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 442dae7332f2..94b531e8f1ec 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,6 +1,6 @@ import pickle from contextlib import contextmanager -from typing import Iterator, List, Union +from typing import Any, Iterator, List, Union import cloudpickle import zmq @@ -13,7 +13,8 @@ IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCGenerateRequest, - RPCHealthRequest, RPCStartupRequest) + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -160,15 +161,16 @@ def run_startup_loop(self) -> None: # Loop until the RPCClient has all the data it needs. client_is_ready = False while not client_is_ready: + response: Any try: identity, message = socket.recv_multipart(copy=False) request: RPCStartupRequest = pickle.loads(message.buffer) # Handle the query from the Client. - if request == RPCStartupRequest.GET_TRACING_ENABLED: - response = self.engine.is_tracing_enabled() - elif request == RPCStartupRequest.IS_SERVER_READY: - response = VLLM_RPC_SUCCESS_STR + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) elif request == RPCStartupRequest.CLIENT_IS_READY: response = VLLM_RPC_SUCCESS_STR # Breakout of loop once client is ready. 
From 1f3fc246b22533f22d7335e8a30fb7fda2d39942 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 10 Sep 2024 22:09:33 -0700 Subject: [PATCH 076/116] fix yapf conflict --- vllm/engine/multiprocessing/client.py | 3 +++ vllm/engine/multiprocessing/engine.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 5e2a2a9ccc5b..7b218cb448e0 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -13,6 +13,8 @@ from vllm.config import DecodingConfig, EngineConfig, LoRAConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs +# yapf conflicts with isort for this block +# yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, @@ -20,6 +22,7 @@ RPCError, RPCGenerateRequest, RPCHealthRequest, RPCStartupRequest, RPCStartupResponse) +# yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 94b531e8f1ec..6eb17dd6934c 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -8,6 +8,8 @@ from vllm import AsyncEngineArgs, LLMEngine from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) +# yapf conflicts with isort for this block +# yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, @@ -15,6 +17,7 @@ RPCError, RPCGenerateRequest, RPCHealthRequest, RPCStartupRequest, RPCStartupResponse) +# yapf: enable from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext From ead62dda25792cfd2318215257d23c464f708a65 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 11 Sep 2024 18:30:17 +0000 Subject: [PATCH 077/116] fix entrypoints tests --- vllm/engine/multiprocessing/client.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7b218cb448e0..e3dc7295dd06 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -183,6 +183,7 @@ async def run_output_handler_loop(self): rpc_error: RPCError = request_outputs request_id = rpc_error.request_id exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored else: # MPLLMEngine should always return an RPCError to # the output_socket when an issue arises. @@ -194,9 +195,11 @@ async def run_output_handler_loop(self): "MPLLMEngine. 
This should never happen.", error) request_id = None exception = error + is_engine_errored = True - # If this is the first error, set _errored_with - if not self._errored_with: + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: self._errored_with = exception if request_id is None: From 672fb81f7021caba202b385741ed5911a52d02bd Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 18:59:06 +0000 Subject: [PATCH 078/116] stash --- docs/source/dev/profiling/profiling_index.rst | 4 +- tests/mq_llm_engine/test_error_handling.py | 60 +++++++++++++++---- tests/mq_llm_engine/utils.py | 16 ++++- vllm/engine/multiprocessing/__init__.py | 8 +-- vllm/engine/multiprocessing/client.py | 28 ++++++--- vllm/engine/multiprocessing/engine.py | 57 +++++++----------- vllm/envs.py | 6 +- 7 files changed, 111 insertions(+), 68 deletions(-) diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index e22d54729344..9e8b2f181756 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/. .. tip:: To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. - Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes. - ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000`` + Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes. + ``export VLLM_RPC_TIMEOUT=1800000`` Example commands and usage: =========================== diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 0070158076b0..2fb2cac2fc59 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -9,8 +9,9 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.multiprocessing import ENGINE_DEAD_ERROR +from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.lora.request import LoRARequest from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.usage.usage_lib import UsageContext @@ -75,9 +76,10 @@ async def test_evil_forward(tmp_socket): pass except Exception as e: # Next exception should be an ENGINE_DEAD_ERROR - assert e == ENGINE_DEAD_ERROR, ( + assert client.errored, "Client should be dead." + assert isinstance(e, MQEngineDeadError), ( "Engine should be dead and raise ENGINE_DEAD_ERROR") - assert client.errored + await asyncio.sleep(2.0) try: @@ -120,10 +122,9 @@ async def test_failed_health_check(tmp_socket): try: await client.check_health() except Exception as e: - # First exception should be a RAISED_ERROR + assert client.errored, "Client should be dead." 
assert repr(e) == repr(RAISED_ERROR), ( "Health check raise the original error.") - assert client.errored # Generate call should throw ENGINE_DEAD_ERROR try: @@ -132,10 +133,9 @@ async def test_failed_health_check(tmp_socket): request_id=uuid.uuid4()): pass except Exception as e: - # Next exception should be an ENGINE_DEAD_ERROR - assert e == ENGINE_DEAD_ERROR, ( + assert client.errored, "Client should be dead." + assert isinstance(e, MQEngineDeadError), ( "Engine should be dead and raise ENGINE_DEAD_ERROR") - assert client.errored # Cleanup client.close() @@ -163,7 +163,6 @@ async def test_failed_abort(tmp_socket): assert client.is_running # Firsh check health should work. - await asyncio.sleep(10) await client.check_health() # Trigger an abort on the client side. @@ -175,15 +174,15 @@ async def bad_abort_after_2s(): try: await client.check_health() except Exception as e: - # First exception should be a RAISED_ERROR + assert client.errored, "Client should be dead." assert repr(e) == repr(RAISED_ERROR), ( "Health check raise the original error.") - assert client.errored # Trigger an abort in 2s from now. abort_task = asyncio.create_task(bad_abort_after_2s()) # Exception in abort() will happen during this generation. + # This will kill the engine and should return ENGINE_DEAD_ERROR. try: async for _ in client.generate( inputs="Hello my name is", @@ -191,8 +190,9 @@ async def bad_abort_after_2s(): request_id=uuid.uuid4()): pass except Exception as e: + print(f"error is: {e}") # Next exception should be an ENGINE_DEAD_ERROR - assert e == ENGINE_DEAD_ERROR, ( + assert isinstance(e, MQEngineDeadError), ( "Engine should be dead and raise ENGINE_DEAD_ERROR") assert client.errored @@ -201,6 +201,42 @@ async def bad_abort_after_2s(): client.close() +@pytest.mark.asyncio +async def test_bad_request(tmp_socket): + with RemoteMQLLMEngine( + engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + # This should fail, but not crash the server. + try: + print("calling first generate") + async for _ in client.generate( + inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-1", + lora_request=LoRARequest("invalid-lora", 1, "invalid-path")): + pass + except Exception as e: + print("got exception") + assert isinstance(e, ValueError), ( + "Expected ValueError when a LoRARequest in llm_engine") + + # This request should be okay. + async for _ in client.generate( + inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-2"): + pass + + # Confirm server is still running. + await asyncio.sleep(10.) + await client.check_health() + + # Shutdown. + client.close() + @pytest.mark.asyncio async def test_mp_crash_detection(): diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 0ef649781c46..4a24fd2819d0 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -2,13 +2,25 @@ from typing import Callable from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.engine import MQLLMEngine from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.usage.usage_lib import UsageContext +def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Run engine. 
+ engine.start() + class RemoteMQLLMEngine: - def __init__(self, run_fn: Callable, engine_args: AsyncEngineArgs, - ipc_path: str) -> None: + def __init__(self, engine_args: AsyncEngineArgs, + ipc_path: str, run_fn: Callable = run_normal) -> None: self.engine_args = engine_args self.ipc_path = ipc_path diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 09aa42498183..53ff9a9b1522 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -61,7 +61,7 @@ class RPCStartupResponse: REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] -ENGINE_DEAD_ERROR = MQEngineDeadError( - "Engine loop is not running. Inspect the output to find " - "the stacktrace of the error that caused the engine loop " - "to stop (MQEngineDeadError).") +def ENGINE_DEAD_ERROR(original_error: str) -> MQEngineDeadError: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + f"find the original error: {original_error}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7b218cb448e0..21c286f2bbc0 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -23,7 +23,7 @@ RPCHealthRequest, RPCStartupRequest, RPCStartupResponse) # yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS +from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -173,6 +173,16 @@ async def run_output_handler_loop(self): try: while True: + # Poll, checking for ENGINE_DEAD + while self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. + if self.errored: + for queue in tuple(self.output_queues.values()): + queue.put_nowait(ENGINE_DEAD_ERROR) + return + message: Frame = await self.output_socket.recv(copy=False) request_outputs = pickle.loads(message.buffer) @@ -183,6 +193,7 @@ async def run_output_handler_loop(self): rpc_error: RPCError = request_outputs request_id = rpc_error.request_id exception = rpc_error.exception + engine_errored = rpc_error.is_engine_errored else: # MPLLMEngine should always return an RPCError to # the output_socket when an issue arises. @@ -194,9 +205,10 @@ async def run_output_handler_loop(self): "MPLLMEngine. This should never happen.", error) request_id = None exception = error + engine_errored = True - # If this is the first error, set _errored_with - if not self._errored_with: + # If the engine is DEAD and this issue caused it. + if not self._errored_with and engine_errored: self._errored_with = exception if request_id is None: @@ -229,7 +241,7 @@ async def setup(self): # Start health_loop. self.health_loop = asyncio.create_task( self.run_check_health_loop( - timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS)) + timeout=VLLM_RPC_TIMEOUT)) # Notify MQLLMEngine client is ready to start sending requests. await self._notify_ready(socket) @@ -264,9 +276,9 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, await socket.send_multipart((pickle.dumps(request), ), copy=False) # Make sure the server responds in time. - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: raise TimeoutError("RPCServer didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + f"{VLLM_RPC_TIMEOUT} ms") # Await the data from the Server. 
frame = await socket.recv(copy=False) @@ -298,9 +310,9 @@ async def _send_one_way_rpc_request(request: RPC_REQUEST_T, async def _await_ack(self, error_message: str, socket: Socket): """Await acknowledgement that a request succeeded.""" - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: raise TimeoutError("MQLLMEngine didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS}ms") + f"{VLLM_RPC_TIMEOUT}ms") await self._check_success(error_message, socket) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 6eb17dd6934c..d17607b63575 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,6 +1,6 @@ import pickle from contextlib import contextmanager -from typing import Any, Iterator, List, Union +from typing import Any, Optional, Iterator, List, Union import cloudpickle import zmq @@ -91,7 +91,7 @@ def __init__(self, self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" # Error state. - self._errored = False + self._errored_with: Optional[BaseException] = None @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, @@ -124,14 +124,8 @@ def start(self): self.run_startup_loop() logger.debug("Starting Engine Loop.") self.run_engine_loop() - except Exception as e_core: - try: - logger.exception(repr(e_core)) - if self._errored: - logger.debug("Starting Dead Loop.") - self.run_engine_dead_loop() - except Exception as e_dead_loop: - logger.exception(repr(e_dead_loop)) + except Exception as e: + logger.exception(repr(e)) except KeyboardInterrupt: logger.debug("Shutting down MQLLMEngine.") finally: @@ -205,36 +199,16 @@ def run_engine_loop(self): if not self.use_async_sockets: self._send_outputs(request_outputs) - def run_engine_dead_loop(self): - """Loop for replying to all requests that we are dead.""" - if not self._errored: - raise ValueError("In dead loop, but found _errored=False") - - # Send a single message that we are dead. 
- rpc_err = RPCError(request_id=None, - is_engine_errored=True, - exception=ENGINE_DEAD_ERROR) - self._send_outputs(rpc_err) - - while True: - # Poll until there is a request - while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for new requests in dead loop.") - - # Handle any new data, replying with EngineDeadError - self.handle_new_input() - def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" try: return self.engine.step() except Exception as e: - self._errored = True + self._set_errored(e) rpc_err = RPCError(request_id=None, is_engine_errored=True, exception=e) - logger.exception(repr(e)) self._send_outputs(rpc_err) raise e @@ -259,8 +233,7 @@ def handle_new_input(self): raise ValueError("Unknown RPCRequest Type: {request}") except Exception as e: - self._errored = True - logger.exception(repr(e)) + self._set_errored(e) self._send_unhealthy(e) raise e @@ -268,10 +241,10 @@ def _handle_generate_request(self, request: RPCGenerateRequest): """Handle RPCGenerateRequest by adding it to the LLMEngine.""" request_id = request.request_id - if self._errored: + if self._is_errored(): rpc_err = RPCError(request_id=request_id, is_engine_errored=True, - exception=ENGINE_DEAD_ERROR) + exception=ENGINE_DEAD_ERROR(self._errored_with)) self._send_outputs(rpc_err) try: @@ -304,8 +277,8 @@ def _handle_abort_request(self, request: RPCAbortRequest): logger.info("Aborted request %s.", request.request_id) def _handle_health_request(self): - if self._errored: - self._send_unhealthy(ENGINE_DEAD_ERROR) + if self._is_errored(): + self._send_unhealthy(self._errored_with) # Raises error if unhealthy. self.engine.check_health() @@ -332,6 +305,16 @@ def _async_socket_engine_callback(self, self._send_outputs(request_outputs) self.handle_new_input() + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + def _is_errored(self) -> bool: + """Check _errored status.""" + return self._errored_with is not None + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): diff --git a/vllm/envs.py b/vllm/envs.py index ed45047e9f8f..5ed226bbef06 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,7 +57,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False - VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 + VLLM_RPC_TIMEOUT: int = 5000 VLLM_ALLOW_ENGINE_USE_RAY: bool = False VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None @@ -383,8 +383,8 @@ def get_default_config_root(): # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_GET_DATA_TIMEOUT_MS": - lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "5000")), # If set, allow running the engine as a separate ray actor, # which is a deprecated feature soon to be removed. 
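The hunks above replace the single ENGINE_DEAD_ERROR constant with a factory that embeds the original failure, and have the client poll its output socket with VLLM_RPC_TIMEOUT so a dead engine is surfaced to every pending request instead of hanging. The snippet below is a minimal, self-contained sketch of that fan-out pattern, for illustration only; it is not vLLM code, and every name in it (EngineDeadError, make_dead_error, ToyClient) is hypothetical.

import asyncio
from typing import Dict, Optional


class EngineDeadError(RuntimeError):
    """Signals that the background engine loop has stopped."""


def make_dead_error(original: Optional[BaseException] = None) -> EngineDeadError:
    # Embed the root cause in the message so callers see why the engine died.
    if original is None:
        return EngineDeadError("Engine loop is not running.")
    return EngineDeadError(
        f"Engine loop is not running. Original error: {original!r}.")


class ToyClient:
    """Stand-in for an RPC client with one output queue per in-flight request."""

    def __init__(self) -> None:
        self.output_queues: Dict[str, asyncio.Queue] = {}
        self._errored_with: Optional[BaseException] = None

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    async def output_handler_loop(self, outputs: asyncio.Queue,
                                  poll_timeout: float) -> None:
        """Route engine outputs to request queues; fan out an error if dead."""
        while True:
            try:
                # Poll with a timeout (the real client polls a zmq socket with
                # VLLM_RPC_TIMEOUT) so an errored engine is noticed promptly.
                request_id, result = await asyncio.wait_for(
                    outputs.get(), timeout=poll_timeout)
            except asyncio.TimeoutError:
                if self.errored:
                    # Alert every pending request so none of them hang forever.
                    for queue in tuple(self.output_queues.values()):
                        queue.put_nowait(make_dead_error(self._errored_with))
                    return
                continue
            self.output_queues[request_id].put_nowait(result)

A coroutine awaiting its per-request queue then receives either a normal result or an EngineDeadError carrying the root cause, which is the behaviour exercised by the error-handling tests in the patches above.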
From 86312e44ecbb36153831998e89fcb3a7017f412c Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 11 Sep 2024 19:34:54 +0000 Subject: [PATCH 079/116] fix Intel/TPU tests --- vllm/executor/cpu_executor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index ec9b24ce1318..87928f828712 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -103,6 +103,7 @@ def _init_executor(self) -> None: )) for rank in range(1, world_size) ] + self.worker_monitor = None if world_size != 1 or is_async: if is_async: async_worker_list = self.workers + [self.driver_worker] From 78b9e212c8506926f5a01918f992ea306885cef0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 19:49:57 +0000 Subject: [PATCH 080/116] fix --- vllm/engine/multiprocessing/client.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 85b32af25758..b5af02507fd1 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -193,11 +193,7 @@ async def run_output_handler_loop(self): rpc_error: RPCError = request_outputs request_id = rpc_error.request_id exception = rpc_error.exception -<<<<<<< HEAD - engine_errored = rpc_error.is_engine_errored -======= is_engine_errored = rpc_error.is_engine_errored ->>>>>>> ead62dda25792cfd2318215257d23c464f708a65 else: # MPLLMEngine should always return an RPCError to # the output_socket when an issue arises. @@ -209,18 +205,11 @@ async def run_output_handler_loop(self): "MPLLMEngine. This should never happen.", error) request_id = None exception = error -<<<<<<< HEAD - engine_errored = True - - # If the engine is DEAD and this issue caused it. - if not self._errored_with and engine_errored: -======= is_engine_errored = True # Set to error state only on engine critical error # (and record only the first one) if is_engine_errored and not self._errored_with: ->>>>>>> ead62dda25792cfd2318215257d23c464f708a65 self._errored_with = exception if request_id is None: From 66c696157b5400a50f2b00510ead6a254b9900f3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 20:06:26 +0000 Subject: [PATCH 081/116] formatting --- tests/mq_llm_engine/test_error_handling.py | 31 +++++++++++----------- tests/mq_llm_engine/utils.py | 9 ++++--- vllm/engine/multiprocessing/__init__.py | 5 ++-- vllm/engine/multiprocessing/client.py | 12 ++++----- vllm/engine/multiprocessing/engine.py | 13 ++++----- 5 files changed, 34 insertions(+), 36 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 2fb2cac2fc59..cd0b1cf349d9 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -11,9 +11,9 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.multiprocessing.engine import MQLLMEngine -from vllm.lora.request import LoRARequest from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.lora.request import LoRARequest from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser @@ -79,7 +79,6 @@ async def test_evil_forward(tmp_socket): assert client.errored, "Client should be dead." 
assert isinstance(e, MQEngineDeadError), ( "Engine should be dead and raise ENGINE_DEAD_ERROR") - await asyncio.sleep(2.0) try: @@ -203,20 +202,20 @@ async def bad_abort_after_2s(): @pytest.mark.asyncio async def test_bad_request(tmp_socket): - with RemoteMQLLMEngine( - engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: client = await engine.make_client() # This should fail, but not crash the server. try: print("calling first generate") - async for _ in client.generate( - inputs="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-1", - lora_request=LoRARequest("invalid-lora", 1, "invalid-path")): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-1", + lora_request=LoRARequest( + "invalid-lora", 1, + "invalid-path")): pass except Exception as e: print("got exception") @@ -224,19 +223,19 @@ async def test_bad_request(tmp_socket): "Expected ValueError when a LoRARequest in llm_engine") # This request should be okay. - async for _ in client.generate( - inputs="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-2"): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-2"): pass - + # Confirm server is still running. await asyncio.sleep(10.) await client.check_health() - + # Shutdown. client.close() + @pytest.mark.asyncio async def test_mp_crash_detection(): diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 4a24fd2819d0..0b00c4209560 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -2,8 +2,8 @@ from typing import Callable from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.multiprocessing.engine import MQLLMEngine from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import MQLLMEngine from vllm.usage.usage_lib import UsageContext @@ -17,10 +17,13 @@ def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): # Run engine. engine.start() + class RemoteMQLLMEngine: - def __init__(self, engine_args: AsyncEngineArgs, - ipc_path: str, run_fn: Callable = run_normal) -> None: + def __init__(self, + engine_args: AsyncEngineArgs, + ipc_path: str, + run_fn: Callable = run_normal) -> None: self.engine_args = engine_args self.ipc_path = ipc_path diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 53ff9a9b1522..96b8dbc5ff08 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -61,7 +61,8 @@ class RPCStartupResponse: REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] -def ENGINE_DEAD_ERROR(original_error: str) -> MQEngineDeadError: + +def ENGINE_DEAD_ERROR(error: BaseException) -> MQEngineDeadError: return MQEngineDeadError( "Engine loop is not running. 
Inspect the stacktrace to " - f"find the original error: {original_error}.") + f"find the original error {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index b5af02507fd1..d223dee317c5 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -73,7 +73,6 @@ class MQLLMEngineClient: def __init__(self, ipc_path: str, engine_config: EngineConfig): self.context = zmq.asyncio.Context() self._errored_with: Optional[BaseException] = None - self.dead_error = ENGINE_DEAD_ERROR # Get the configs. self.model_config = engine_config.model_config @@ -179,8 +178,8 @@ async def run_output_handler_loop(self): # If errored, alert all running requests. if self.errored: - for queue in tuple(self.output_queues.values()): - queue.put_nowait(ENGINE_DEAD_ERROR) + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait(ENGINE_DEAD_ERROR) return message: Frame = await self.output_socket.recv(copy=False) @@ -241,8 +240,7 @@ async def setup(self): # Start health_loop. self.health_loop = asyncio.create_task( - self.run_check_health_loop( - timeout=VLLM_RPC_TIMEOUT)) + self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) # Notify MQLLMEngine client is ready to start sending requests. await self._notify_ready(socket) @@ -399,8 +397,8 @@ async def generate( """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. - if self.errored: - raise ENGINE_DEAD_ERROR + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) # 1) Create output queue for this requests. queue: asyncio.Queue[Union[RequestOutput, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index d17607b63575..e025ba698ed6 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,6 +1,6 @@ import pickle from contextlib import contextmanager -from typing import Any, Optional, Iterator, List, Union +from typing import Any, Iterator, List, Optional, Union import cloudpickle import zmq @@ -241,7 +241,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest): """Handle RPCGenerateRequest by adding it to the LLMEngine.""" request_id = request.request_id - if self._is_errored(): + if self._errored_with is not None: rpc_err = RPCError(request_id=request_id, is_engine_errored=True, exception=ENGINE_DEAD_ERROR(self._errored_with)) @@ -263,8 +263,9 @@ def _handle_generate_request(self, request: RPCGenerateRequest): # We do not set self._errored = True here, since the error # is due to an issue adding this request to the engine, # rather than an issue with the engine itself. + is_errored = self._errored_with is not None rpc_err = RPCError(request_id=request_id, - is_engine_errored=self._errored, + is_engine_errored=is_errored, exception=e) self._send_outputs(rpc_err) @@ -277,7 +278,7 @@ def _handle_abort_request(self, request: RPCAbortRequest): logger.info("Aborted request %s.", request.request_id) def _handle_health_request(self): - if self._is_errored(): + if self._errored_with is not None: self._send_unhealthy(self._errored_with) # Raises error if unhealthy. 
@@ -311,10 +312,6 @@ def _set_errored(self, e: BaseException): if self._errored_with is None: self._errored_with = e - def _is_errored(self) -> bool: - """Check _errored status.""" - return self._errored_with is not None - def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): From 6e1e2bb8c11c2f7979af09e58706cb64594d54b8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 20:18:58 +0000 Subject: [PATCH 082/116] cleanup --- tests/mq_llm_engine/test_error_handling.py | 4 ++++ vllm/engine/multiprocessing/engine.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index cd0b1cf349d9..6fc4f4586b87 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -34,6 +34,7 @@ def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): engine_args=engine_args, usage_context=UsageContext.UNKNOWN_CONTEXT, ipc_path=ipc_path) + # Raise error during first forward pass. engine.engine.model_executor.execute_model = Mock(side_effect=RAISED_ERROR) @@ -99,6 +100,7 @@ def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, engine_args=engine_args, usage_context=UsageContext.UNKNOWN_CONTEXT, ipc_path=ipc_path) + # Raise error during first forward pass. engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) @@ -146,8 +148,10 @@ def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): engine_args=engine_args, usage_context=UsageContext.UNKNOWN_CONTEXT, ipc_path=ipc_path) + # Raise error during abort call. engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) + # Run engine. engine.start() diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index e025ba698ed6..789e89646687 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -308,7 +308,6 @@ def _async_socket_engine_callback(self, def _set_errored(self, e: BaseException): """Log and set errored status if this is the first issue.""" - logger.exception(repr(e)) if self._errored_with is None: self._errored_with = e From 610b34923a57a6cd9e5ee3e5e44c3d603311408f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 20:20:04 +0000 Subject: [PATCH 083/116] cleanup --- tests/entrypoints/openai/test_serving_chat.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index f2a4e3e27cd3..de2a932199a0 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -76,9 +76,6 @@ def test_serving_chat_should_set_correct_max_tokens(): with suppress(Exception): asyncio.run(serving_chat.create_chat_completion(req)) - print(mock_engine) - print(mock_engine.generate) - print(mock_engine.generate.call_args) assert mock_engine.generate.call_args.args[1].max_tokens == 93 req.max_tokens = 10 From 28bb8a450af8a7355113aae936ff691b262f511e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 20:30:30 +0000 Subject: [PATCH 084/116] format --- vllm/engine/multiprocessing/__init__.py | 8 +++++++- vllm/engine/multiprocessing/engine.py | 9 +++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 96b8dbc5ff08..8e61cc8a0454 100644 --- 
a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -62,7 +62,13 @@ class RPCStartupResponse: REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] -def ENGINE_DEAD_ERROR(error: BaseException) -> MQEngineDeadError: +def ENGINE_DEAD_ERROR( + error: Optional[BaseException]) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + f"find the original error") + return MQEngineDeadError( "Engine loop is not running. Inspect the stacktrace to " f"find the original error {repr(error)}.") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 789e89646687..5770524b5479 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -92,6 +92,15 @@ def __init__(self, # Error state. self._errored_with: Optional[BaseException] = None + + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, From b266249f65b0a77979f490af50b0dfcdfe3eff19 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 21:26:57 +0000 Subject: [PATCH 085/116] fix poller --- vllm/engine/multiprocessing/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index d223dee317c5..e1d9892acd65 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -173,13 +173,13 @@ async def run_output_handler_loop(self): try: while True: # Poll, checking for ENGINE_DEAD - while self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: logger.debug("Waiting for output from MQLLMEngine.") # If errored, alert all running requests. if self.errored: for queue_j in tuple(self.output_queues.values()): - queue_j.put_nowait(ENGINE_DEAD_ERROR) + queue_j.put_nowait(ENGINE_DEAD_ERROR(self._errored_with)) return message: Frame = await self.output_socket.recv(copy=False) From f8036a58c8cf72a2983c135ea70445fd25a0dd37 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 21:58:12 +0000 Subject: [PATCH 086/116] add graceful shutdown on abort after client closed --- vllm/engine/multiprocessing/client.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index e1d9892acd65..6dd5dfb68516 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -303,11 +303,17 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, async def _send_one_way_rpc_request(request: RPC_REQUEST_T, socket: Socket): """Send one-way RPC request to trigger an action.""" - + # Raise handlable error for graceful shutdown. + if socket.closed: + raise MQClientClosedError() + await socket.send_multipart((pickle.dumps(request), )) async def _await_ack(self, error_message: str, socket: Socket): """Await acknowledgement that a request succeeded.""" + # Raise handlable error for graceful shutdown. 
+ if socket.closed: + raise MQClientClosedError() if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: raise TimeoutError("MQLLMEngine didn't reply within " @@ -317,6 +323,10 @@ async def _await_ack(self, error_message: str, socket: Socket): @staticmethod async def _check_success(error_message: str, socket: Socket): + # Raise handlable error for graceful shutdown. + if socket.closed: + raise MQClientClosedError() + frame = await socket.recv(copy=False) response = pickle.loads(frame.buffer) From a649f754c7b0cddded5735da4eb5087483319854 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 22:08:38 +0000 Subject: [PATCH 087/116] cleanup formatting --- vllm/engine/multiprocessing/__init__.py | 4 ++-- vllm/engine/multiprocessing/client.py | 8 +++++--- vllm/engine/multiprocessing/engine.py | 2 -- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 8e61cc8a0454..cf0b36bee7aa 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -63,11 +63,11 @@ class RPCStartupResponse: def ENGINE_DEAD_ERROR( - error: Optional[BaseException]) -> MQEngineDeadError: + error: Optional[BaseException] = None) -> MQEngineDeadError: if error is None: return MQEngineDeadError( "Engine loop is not running. Inspect the stacktrace to " - f"find the original error") + "find the original error") return MQEngineDeadError( "Engine loop is not running. Inspect the stacktrace to " diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 6dd5dfb68516..e7af5aed9013 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -173,13 +173,15 @@ async def run_output_handler_loop(self): try: while True: # Poll, checking for ENGINE_DEAD - while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: logger.debug("Waiting for output from MQLLMEngine.") # If errored, alert all running requests. if self.errored: for queue_j in tuple(self.output_queues.values()): - queue_j.put_nowait(ENGINE_DEAD_ERROR(self._errored_with)) + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) return message: Frame = await self.output_socket.recv(copy=False) @@ -306,7 +308,7 @@ async def _send_one_way_rpc_request(request: RPC_REQUEST_T, # Raise handlable error for graceful shutdown. if socket.closed: raise MQClientClosedError() - + await socket.send_multipart((pickle.dumps(request), )) async def _await_ack(self, error_message: str, socket: Socket): diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 5770524b5479..df13532dcd04 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -92,7 +92,6 @@ def __init__(self, # Error state. 
self._errored_with: Optional[BaseException] = None - @property def dead_error(self) -> BaseException: @@ -101,7 +100,6 @@ def dead_error(self) -> BaseException: else: return ENGINE_DEAD_ERROR() - @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): From 5b3535d1188a7b2f6f59dbeb15096326e7b6c1b2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 23:24:27 +0000 Subject: [PATCH 088/116] added test abort --- tests/mq_llm_engine/test_abort.py | 0 tests/mq_llm_engine/test_error_handling.py | 93 +++++++--------------- vllm/engine/multiprocessing/__init__.py | 2 +- vllm/engine/multiprocessing/client.py | 43 ++++------ vllm/engine/multiprocessing/engine.py | 4 +- 5 files changed, 46 insertions(+), 96 deletions(-) create mode 100644 tests/mq_llm_engine/test_abort.py diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 6fc4f4586b87..27ebde7fef0e 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -19,7 +19,8 @@ MODEL = "Qwen/Qwen2-0.5B-Instruct" ENGINE_ARGS = AsyncEngineArgs(model=MODEL) -RAISED_ERROR = KeyError("foo") +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" @pytest.fixture(scope="function") @@ -36,7 +37,8 @@ def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): ipc_path=ipc_path) # Raise error during first forward pass. - engine.engine.model_executor.execute_model = Mock(side_effect=RAISED_ERROR) + engine.engine.model_executor.execute_model = Mock( + side_effect=RAISED_ERROR(RAISED_VALUE)) # Run engine. engine.start() @@ -50,46 +52,32 @@ async def test_evil_forward(tmp_socket): client = await engine.make_client() - # Fast health probe. - fast_health_probe_task = asyncio.create_task( - client.run_check_health_loop(timeout=1.0)) - # Server should be healthy after initial probe. await asyncio.sleep(2.0) await client.check_health() # Throws an error in first forward pass. - try: + with pytest.raises(RAISED_ERROR): async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass - except Exception as e: - # First exception should be a RAISED_ERROR - assert repr(e) == repr(RAISED_ERROR) - assert client.errored + assert client.errored # Engine is errored, should get ENGINE_DEAD_ERROR. - try: + with pytest.raises(MQEngineDeadError): async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass - except Exception as e: - # Next exception should be an ENGINE_DEAD_ERROR - assert client.errored, "Client should be dead." - assert isinstance(e, MQEngineDeadError), ( - "Engine should be dead and raise ENGINE_DEAD_ERROR") + assert client.errored - await asyncio.sleep(2.0) - try: + await asyncio.sleep(1.0) + with pytest.raises(RAISED_ERROR): await client.check_health() - except Exception as e: - assert repr(e) == repr(RAISED_ERROR), ( - "Health check raise the original error.") + assert client.errored - # Cleanup - await fast_health_probe_task + # Shutdown. client.close() @@ -120,25 +108,18 @@ async def test_failed_health_check(tmp_socket): # Health probe should throw RAISED_ERROR. 
await asyncio.sleep(10) - try: + + with pytest.raises(RAISED_ERROR): await client.check_health() - except Exception as e: - assert client.errored, "Client should be dead." - assert repr(e) == repr(RAISED_ERROR), ( - "Health check raise the original error.") + assert client.errored # Generate call should throw ENGINE_DEAD_ERROR - try: + with pytest.raises(MQEngineDeadError): async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass - except Exception as e: - assert client.errored, "Client should be dead." - assert isinstance(e, MQEngineDeadError), ( - "Engine should be dead and raise ENGINE_DEAD_ERROR") - # Cleanup client.close() @@ -173,34 +154,26 @@ async def bad_abort_after_2s(): await asyncio.sleep(2.0) await client.abort(request_id="foo") - # Immediately should trigger error. - try: - await client.check_health() - except Exception as e: - assert client.errored, "Client should be dead." - assert repr(e) == repr(RAISED_ERROR), ( - "Health check raise the original error.") - # Trigger an abort in 2s from now. abort_task = asyncio.create_task(bad_abort_after_2s()) # Exception in abort() will happen during this generation. - # This will kill the engine and should return ENGINE_DEAD_ERROR. - try: + # This will kill the engine and should return ENGINE_DEAD_ERROR + # with reference to the original KeyError("foo") + with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( inputs="Hello my name is", sampling_params=SamplingParams(max_tokens=2000), request_id=uuid.uuid4()): pass - except Exception as e: - print(f"error is: {e}") - # Next exception should be an ENGINE_DEAD_ERROR - assert isinstance(e, MQEngineDeadError), ( - "Engine should be dead and raise ENGINE_DEAD_ERROR") - assert client.errored - + assert "KeyError" in repr(execinfo.value) + assert client.errored await abort_task + # This should raise the original error. + with pytest.raises(RAISED_ERROR): + await client.check_health() + client.close() @@ -211,20 +184,14 @@ async def test_bad_request(tmp_socket): client = await engine.make_client() - # This should fail, but not crash the server. - try: - print("calling first generate") + # Invalid request should fail, but not crash the server. + with pytest.raises(ValueError): async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( - "invalid-lora", 1, - "invalid-path")): + "invalid-lora", 1, "invalid-path")): pass - except Exception as e: - print("got exception") - assert isinstance(e, ValueError), ( - "Expected ValueError when a LoRARequest in llm_engine") # This request should be okay. async for _ in client.generate(inputs="Hello my name is", @@ -232,10 +199,6 @@ async def test_bad_request(tmp_socket): request_id="abcd-2"): pass - # Confirm server is still running. - await asyncio.sleep(10.) - await client.check_health() - # Shutdown. client.close() diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index cf0b36bee7aa..df9941b3eb12 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -71,4 +71,4 @@ def ENGINE_DEAD_ERROR( return MQEngineDeadError( "Engine loop is not running. 
Inspect the stacktrace to " - f"find the original error {repr(error)}.") + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index e7af5aed9013..b380de3f7926 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -165,10 +165,10 @@ async def run_check_health_loop(self, timeout: int): logger.debug("Shutting down MQLLMEngineClient check health loop.") except Exception as e: - self.raise_exception(e) + self._set_errored(e) async def run_output_handler_loop(self): - """Get RequestOutputs from Engine and stream to request Queues""" + """Get RequestOutputs from Engine and stream to Request Queues""" try: while True: @@ -249,11 +249,7 @@ async def setup(self): def close(self): """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.output_socket.close() - self.input_socket.close() - self.health_socket.close() + # Close all sockets and terminate the context. self.context.destroy(linger=0) # Cancel background tasks. @@ -261,7 +257,7 @@ def close(self): self.health_loop.cancel() self.output_loop.cancel() - def raise_exception(self, e: BaseException): + def _set_errored(self, e: BaseException): logger.exception(repr(e)) if self._errored_with is None: self._errored_with = e @@ -285,19 +281,10 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, frame = await socket.recv(copy=False) data = pickle.loads(frame.buffer) - if isinstance(data, Exception): - # Re-raise exceptions returned by the server + if isinstance(data, BaseException): raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) + elif not isinstance(data, expected_type): + raise ValueError(error_message) return data @@ -305,7 +292,7 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, async def _send_one_way_rpc_request(request: RPC_REQUEST_T, socket: Socket): """Send one-way RPC request to trigger an action.""" - # Raise handlable error for graceful shutdown. + if socket.closed: raise MQClientClosedError() @@ -313,7 +300,7 @@ async def _send_one_way_rpc_request(request: RPC_REQUEST_T, async def _await_ack(self, error_message: str, socket: Socket): """Await acknowledgement that a request succeeded.""" - # Raise handlable error for graceful shutdown. + if socket.closed: raise MQClientClosedError() @@ -325,17 +312,19 @@ async def _await_ack(self, error_message: str, socket: Socket): @staticmethod async def _check_success(error_message: str, socket: Socket): - # Raise handlable error for graceful shutdown. 
+ """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + if socket.closed: raise MQClientClosedError() frame = await socket.recv(copy=False) response = pickle.loads(frame.buffer) - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, BaseException): - logger.error(error_message) - raise response + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) or + response != VLLM_RPC_SUCCESS_STR): raise ValueError(error_message) async def get_tokenizer(self, lora_request: LoRARequest): diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index df13532dcd04..7475e0931bd2 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -141,9 +141,7 @@ def start(self): def cleanup(self): """Cleanup zeromq state on shutdown.""" - self.input_socket.close() - self.output_socket.close() - self.health_socket.close() + # Closes all sockets and destroys context. self.ctx.destroy(linger=0) del self.engine From 7097e05b30741fb675a4370f1e0fd9062669f1d0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 11 Sep 2024 23:33:39 +0000 Subject: [PATCH 089/116] fix up tests --- tests/mq_llm_engine/test_error_handling.py | 18 ++++++++++++------ vllm/engine/multiprocessing/client.py | 6 +++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 27ebde7fef0e..0da145e6cfc6 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -9,6 +9,7 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.llm_engine import LLMEngine from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.multiprocessing.engine import MQLLMEngine from vllm.entrypoints.openai.api_server import build_async_engine_client @@ -168,12 +169,13 @@ async def bad_abort_after_2s(): pass assert "KeyError" in repr(execinfo.value) assert client.errored + await abort_task # This should raise the original error. with pytest.raises(RAISED_ERROR): await client.check_health() - + client.close() @@ -190,7 +192,8 @@ async def test_bad_request(tmp_socket): sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( - "invalid-lora", 1, "invalid-path")): + "invalid-lora", 1, + "invalid-path")): pass # This request should be okay. @@ -204,14 +207,17 @@ async def test_bad_request(tmp_socket): @pytest.mark.asyncio -async def test_mp_crash_detection(): +async def test_mp_crash_detection(monkeypatch): parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") parser = make_arg_parser(parser) args = parser.parse_args([]) - # use an invalid tensor_parallel_size to trigger the - # error in the server - args.tensor_parallel_size = 65536 + + # When LLMEngine is loaded, it will crash. 
+ def mock_init(): + raise ValueError + + monkeypatch.setattr(LLMEngine, "__init__", mock_init) start = time.perf_counter() async with build_async_engine_client(args): diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index b380de3f7926..bc93dc34fdf7 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -11,7 +11,7 @@ from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket -from vllm.config import DecodingConfig, EngineConfig, LoRAConfig, ModelConfig +from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block # yapf: disable @@ -323,8 +323,8 @@ async def _check_success(error_message: str, socket: Socket): # Raise error if unsuccessful if isinstance(response, BaseException): raise response - elif (not isinstance(response, str) or - response != VLLM_RPC_SUCCESS_STR): + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): raise ValueError(error_message) async def get_tokenizer(self, lora_request: LoRARequest): From ad3d0f87eedb37ba913c6932565d60c322c308b1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 00:11:26 +0000 Subject: [PATCH 090/116] added abort tests --- tests/mq_llm_engine/test_abort.py | 89 +++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py index e69de29bb2d1..8e891d14af71 100644 --- a/tests/mq_llm_engine/test_abort.py +++ b/tests/mq_llm_engine/test_abort.py @@ -0,0 +1,89 @@ +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "Qwen/Qwen2-0.5B-Instruct" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_id_to_be_aborted = "request-aborted" + request_ids_a = [f"request-a-{idx}" for idx in range(10)] + request_ids_b = [f"request-b-{idx}" for idx in range(10)] + + async def run_to_completion(request_id) -> bool: + EXPECTED = 250 + count = 0 + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams( + max_tokens=EXPECTED, + temperature=0), + request_id=request_id): + count += 1 + await asyncio.sleep(0.) + + # Confirm we generated all the tokens we expected. + return count == EXPECTED + + async def run_to_be_aborted(request_id): + EXPECTED = 250 + count = 0 + try: + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams( + max_tokens=EXPECTED, + temperature=0), + request_id=request_id): + count += 1 + await asyncio.sleep(0.) + + # Confirm this was actually stopped. + except asyncio.CancelledError: + assert (count < EXPECTED) + + # Create concurrent requests. 
+ tasks_a = [ + asyncio.create_task(run_to_completion(request_id)) + for request_id in request_ids_a + ] + task_aborted = asyncio.create_task( + run_to_be_aborted(request_id_to_be_aborted)) + tasks_b = [ + asyncio.create_task(run_to_completion(request_id)) + for request_id in request_ids_b + ] + + await asyncio.sleep(0.5) + await client.abort(request_id_to_be_aborted) + + # Confirm that we got all the EXPECTED tokens from the requests. + for task in tasks_a: + assert (await task), "Expected this task to run to completion." + for task in tasks_b: + assert (await task), "Expected this task to run to completion." + + # Cancel task (this will hang indefinitely if not). + task_aborted.cancel() + + # Shutdown. + client.close() From 6e9c6c9c7a1c7bb167c7cf669af060cc9789d728 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 00:16:48 +0000 Subject: [PATCH 091/116] added another accurayc test --- tests/entrypoints/openai/test_accuracy.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b442a903c33a..cf6dbcae131d 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -21,11 +21,12 @@ @pytest.fixture(scope="module") -def server(): - args = [ - "--max-model-len", "4096", "--enable-chunked-prefill", - "--disable-log-requests", "--enforce-eager" - ] +@pytest.mark.parametrize( + "more_args", + [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]) +def server(more_args): + args = ["--max-model-len", "4096", "--disable-log-requests"] + args.extend(more_args) with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server From fb8e2f93b90638131462ee66d7cf1a40fa7f0415 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 00:28:21 +0000 Subject: [PATCH 092/116] add multistep test for accuracy of mq llm engine --- tests/entrypoints/openai/test_accuracy.py | 54 +++++++++-------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index cf6dbcae131d..88a3341e2d97 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -18,39 +18,29 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 +DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] +MORE_ARGS_LIST = [["--enable-chunked-prefill"], + ["--num-scheduler-steps", "8"]] - -@pytest.fixture(scope="module") -@pytest.mark.parametrize( - "more_args", - [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]) -def server(more_args): - args = ["--max-model-len", "4096", "--disable-log-requests"] +@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) +def test_lm_eval_accuracy(more_args): + args = DEFAULT_ARGS args.extend(more_args) with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def server_data(server): - return { - "url": f"{server.url_for('v1')}/completions", - } - - -def test_lm_eval_accuracy(server_data): - model_args = (f"model={MODEL_NAME}," - f"base_url={server_data['url']}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") - - results = lm_eval.simple_evaluate( - model="local-completions", - model_args=model_args, - tasks=TASK, - ) - - measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and 
measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + url = f"{remote_server.url_for('v1')}/completions" + + model_args = (f"model={MODEL_NAME}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" From 75523b208beb53e9e2f39cf9c10833a6947cd352 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 00:29:00 +0000 Subject: [PATCH 093/116] added test genertion --- tests/mq_llm_engine/test_generation.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/mq_llm_engine/test_generation.py diff --git a/tests/mq_llm_engine/test_generation.py b/tests/mq_llm_engine/test_generation.py new file mode 100644 index 000000000000..e69de29bb2d1 From 5546d2ed5b0b1dc235bd03503de26d32f19a41ce Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 00:30:37 +0000 Subject: [PATCH 094/116] fixed accuracy test launch --- tests/entrypoints/openai/test_accuracy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 88a3341e2d97..e731a051dfd1 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -24,12 +24,14 @@ @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) def test_lm_eval_accuracy(more_args): - args = DEFAULT_ARGS + args = list(DEFAULT_ARGS) args.extend(more_args) + print(f"Running with: {args}") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: url = f"{remote_server.url_for('v1')}/completions" - + model_args = (f"model={MODEL_NAME}," f"base_url={url}," f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") From 6403f49860c4dde65a17e677ab825fb121698129 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 01:06:36 +0000 Subject: [PATCH 095/116] added load test --- tests/entrypoints/openai/test_accuracy.py | 13 +++++++------ tests/mq_llm_engine/test_generation.py | 0 2 files changed, 7 insertions(+), 6 deletions(-) delete mode 100644 tests/mq_llm_engine/test_generation.py diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index e731a051dfd1..2ad8460023c2 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -19,8 +19,8 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] -MORE_ARGS_LIST = [["--enable-chunked-prefill"], - ["--num-scheduler-steps", "8"]] +MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] + @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) def test_lm_eval_accuracy(more_args): @@ -32,9 +32,10 @@ def test_lm_eval_accuracy(more_args): with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: url = f"{remote_server.url_for('v1')}/completions" - model_args = (f"model={MODEL_NAME}," - f"base_url={url}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + model_args = ( + f"model={MODEL_NAME}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") results = lm_eval.simple_evaluate( 
model="local-completions", @@ -45,4 +46,4 @@ def test_lm_eval_accuracy(more_args): measured_value = results["results"][TASK][FILTER] assert (measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/mq_llm_engine/test_generation.py b/tests/mq_llm_engine/test_generation.py deleted file mode 100644 index e69de29bb2d1..000000000000 From 3bb5e52e0dccdf4a479b3f7d1617e7c81a17a949 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 15:12:39 +0000 Subject: [PATCH 096/116] remove file --- .../openai/test_openapi_server_ray.py | 111 ------------------ 1 file changed, 111 deletions(-) delete mode 100644 tests/entrypoints/openai/test_openapi_server_ray.py diff --git a/tests/entrypoints/openai/test_openapi_server_ray.py b/tests/entrypoints/openai/test_openapi_server_ray.py deleted file mode 100644 index 1c5f645be121..000000000000 --- a/tests/entrypoints/openai/test_openapi_server_ray.py +++ /dev/null @@ -1,111 +0,0 @@ -import openai # use the official client for correctness check -import pytest -import pytest_asyncio - -from ...utils import VLLM_PATH, RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "facebook/opt-125m" -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() - - -@pytest.fixture(scope="module") -def server(): - args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "2048", - "--enforce-eager", - "--engine-use-ray", - "--chat-template", - str(chatml_jinja_path), - ] - - # Allow `--engine-use-ray`, otherwise the launch of the server throw - # an error due to try to use a deprecated feature - env_dict = {"VLLM_ALLOW_ENGINE_USE_RAY": "1"} - with RemoteOpenAIServer(MODEL_NAME, args, - env_dict=env_dict) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): - models = await client.models.list() - models = models.data - served_model = models[0] - assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in models) - - -@pytest.mark.asyncio -async def test_single_completion(client: openai.AsyncOpenAI): - completion = await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert len(completion.choices) == 1 - assert len(completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 5 - - -@pytest.mark.asyncio -async def test_single_chat_session(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] - - # test single completion - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10, - logprobs=True, - top_logprobs=5) - assert chat_completion.id is not None - assert len(chat_completion.choices) == 1 - - choice = chat_completion.choices[0] - assert choice.finish_reason == "length" - assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=55, total_tokens=65) - - message = choice.message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" - messages.append({"role": "assistant", "content": message.content}) - - # test multi-turn dialogue - messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=10, - ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 From 2ac814f91fda70509cbd4c965307a525fdac27d5 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 15:14:57 +0000 Subject: [PATCH 097/116] format --- tests/async_engine/test_openapi_server.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/async_engine/test_openapi_server.py diff --git a/tests/async_engine/test_openapi_server.py b/tests/async_engine/test_openapi_server.py deleted file mode 100644 index e69de29bb2d1..000000000000 From 179a667c9534f445f25f988d613d2b633371c9a8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 15:49:21 +0000 Subject: [PATCH 098/116] added load test --- vllm/engine/multiprocessing/client.py | 3 +-- vllm/engine/multiprocessing/engine.py | 10 ++++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index bc93dc34fdf7..c2897a987d7d 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -120,8 +120,7 @@ def is_unsupported_config(engine_args: AsyncEngineArgs): seed=0, dtype="auto").embedding_mode is_pp = engine_args.pipeline_parallel_size > 1 - is_engine_use_ray = engine_args.engine_use_ray - return is_embedding or is_pp or is_engine_use_ray + return is_embedding or is_pp @contextmanager def get_data_socket(self) -> Iterator[Socket]: diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 7475e0931bd2..e89e8dd2aabf 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -107,12 +107,6 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, engine_config = engine_args.create_engine_config() - if engine_args.engine_use_ray: - raise NotImplementedError( - "--engine-use-ray is not supported for MQLLMEngine. " - "Launch with --disable-frontend-multiprocessing if you " - "need to deploy with this flag (not recommended).") - executor_class = LLMEngine._get_executor_cls(engine_config) return cls( @@ -199,6 +193,10 @@ def run_engine_loop(self): # Engine step. request_outputs = self.engine_step() + for request_output in request_outputs: + if request_output.request_id == "request-211": + print("\n\n\n") + print(request_output) # Send request outputs (if async, done in engine_step callback). 
if not self.use_async_sockets: From 97d6c096381bb95e96590fc14d4006395cb23625 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 15:53:00 +0000 Subject: [PATCH 099/116] format --- tests/mq_llm_engine/test_abort.py | 2 ++ tests/mq_llm_engine/test_error_handling.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py index 8e891d14af71..d973283e3e3f 100644 --- a/tests/mq_llm_engine/test_abort.py +++ b/tests/mq_llm_engine/test_abort.py @@ -1,3 +1,5 @@ +"""Test that aborting is handled properly.""" + import asyncio import tempfile import uuid diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 0da145e6cfc6..16e82794513c 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -1,3 +1,5 @@ +"""Test that various errors are handled properly.""" + import asyncio import tempfile import time From 78badc17151db7beb5c98e15f47c7d940adc48c8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 15:53:51 +0000 Subject: [PATCH 100/116] added load test --- tests/mq_llm_engine/test_load.py | 74 ++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/mq_llm_engine/test_load.py diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py new file mode 100644 index 000000000000..b8a1782f8694 --- /dev/null +++ b/tests/mq_llm_engine/test_load.py @@ -0,0 +1,74 @@ +"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.client import MQLLMEngineClient + +MODEL = "Qwen/Qwen2-0.5B-Instruct" +NUM_EXPECTED_TOKENS = 10 +PROMPT = "Hello my name is Robert and I love" +NUM_REQUESTS = 10000 + +# Scenarios to test for num generated token. +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + +async def run_to_completion(client: MQLLMEngineClient, + request_id: str): + + count = 0 + async for out in client.generate(inputs=PROMPT, + sampling_params=SamplingParams( + max_tokens=NUM_EXPECTED_TOKENS, + temperature=0), + request_id=request_id): + + count += 1 + await asyncio.sleep(0.) + + # Confirm we generated all the tokens we expected. + return count, request_id + +@pytest.mark.asyncio +async def test_generation(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks = [] + for request_id in request_ids: + tasks.append(asyncio.create_task(run_to_completion(client, + request_id))) + + # Confirm that we got all the EXPECTED tokens from the requests. + failed_request_id = None + tokens = None + for task in tasks: + num_generated_tokens, request_id = await task + if (num_generated_tokens != NUM_EXPECTED_TOKENS and + failed_request_id is None): + failed_request_id = request_id + tokens = num_generated_tokens + + assert failed_request_id is None, ( + f"{failed_request_id} generated {tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + # Shutdown. 
+ client.close() From a4997339690eeaa2a6ced575078803a872bbd23f Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 12 Sep 2024 16:33:28 +0000 Subject: [PATCH 101/116] format --- tests/mq_llm_engine/test_load.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py index b8a1782f8694..3725d93e56f3 100644 --- a/tests/mq_llm_engine/test_load.py +++ b/tests/mq_llm_engine/test_load.py @@ -7,7 +7,6 @@ import pytest from tests.mq_llm_engine.utils import RemoteMQLLMEngine - from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing.client import MQLLMEngineClient @@ -20,27 +19,29 @@ # Scenarios to test for num generated token. ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) + @pytest.fixture(scope="function") def tmp_socket(): with tempfile.TemporaryDirectory() as td: yield f"ipc://{td}/{uuid.uuid4()}" -async def run_to_completion(client: MQLLMEngineClient, - request_id: str): - + +async def run_to_completion(client: MQLLMEngineClient, request_id: str): + count = 0 async for out in client.generate(inputs=PROMPT, - sampling_params=SamplingParams( - max_tokens=NUM_EXPECTED_TOKENS, - temperature=0), - request_id=request_id): - + sampling_params=SamplingParams( + max_tokens=NUM_EXPECTED_TOKENS, + temperature=0), + request_id=request_id): + count += 1 await asyncio.sleep(0.) # Confirm we generated all the tokens we expected. return count, request_id + @pytest.mark.asyncio async def test_generation(tmp_socket): with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, @@ -53,16 +54,16 @@ async def test_generation(tmp_socket): # Create concurrent requests. tasks = [] for request_id in request_ids: - tasks.append(asyncio.create_task(run_to_completion(client, - request_id))) + tasks.append( + asyncio.create_task(run_to_completion(client, request_id))) # Confirm that we got all the EXPECTED tokens from the requests. failed_request_id = None tokens = None for task in tasks: num_generated_tokens, request_id = await task - if (num_generated_tokens != NUM_EXPECTED_TOKENS and - failed_request_id is None): + if (num_generated_tokens != NUM_EXPECTED_TOKENS + and failed_request_id is None): failed_request_id = request_id tokens = num_generated_tokens From 6a5d8d812d7d40f8c97d8a4d29aa1c56c3f88fb2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 20:20:14 +0000 Subject: [PATCH 102/116] stash --- tests/mq_llm_engine/test_load.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py index b8a1782f8694..d705e370bf21 100644 --- a/tests/mq_llm_engine/test_load.py +++ b/tests/mq_llm_engine/test_load.py @@ -18,7 +18,8 @@ NUM_REQUESTS = 10000 # Scenarios to test for num generated token. 
-ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True, + enable_chunked_prefill=True) @pytest.fixture(scope="function") def tmp_socket(): From 96f84feb30e6087b3e55bdbb0c0279e812533583 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Sep 2024 20:38:47 +0000 Subject: [PATCH 103/116] format --- tests/mq_llm_engine/test_abort.py | 68 ++++++++++--------------------- tests/mq_llm_engine/test_load.py | 27 ++---------- tests/mq_llm_engine/utils.py | 20 ++++++++- 3 files changed, 45 insertions(+), 70 deletions(-) diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py index d973283e3e3f..68cebf6569dd 100644 --- a/tests/mq_llm_engine/test_abort.py +++ b/tests/mq_llm_engine/test_abort.py @@ -6,14 +6,14 @@ import pytest -from tests.mq_llm_engine.utils import RemoteMQLLMEngine -from vllm import SamplingParams +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs MODEL = "Qwen/Qwen2-0.5B-Instruct" ENGINE_ARGS = AsyncEngineArgs(model=MODEL) RAISED_ERROR = KeyError RAISED_VALUE = "foo" +EXPECTED_TOKENS = 250 @pytest.fixture(scope="function") @@ -33,56 +33,32 @@ async def test_abort(tmp_socket): request_ids_a = [f"request-a-{idx}" for idx in range(10)] request_ids_b = [f"request-b-{idx}" for idx in range(10)] - async def run_to_completion(request_id) -> bool: - EXPECTED = 250 - count = 0 - async for _ in client.generate(inputs="Hello my name is", - sampling_params=SamplingParams( - max_tokens=EXPECTED, - temperature=0), - request_id=request_id): - count += 1 - await asyncio.sleep(0.) - - # Confirm we generated all the tokens we expected. - return count == EXPECTED - - async def run_to_be_aborted(request_id): - EXPECTED = 250 - count = 0 - try: - async for _ in client.generate(inputs="Hello my name is", - sampling_params=SamplingParams( - max_tokens=EXPECTED, - temperature=0), - request_id=request_id): - count += 1 - await asyncio.sleep(0.) - - # Confirm this was actually stopped. - except asyncio.CancelledError: - assert (count < EXPECTED) - - # Create concurrent requests. - tasks_a = [ - asyncio.create_task(run_to_completion(request_id)) - for request_id in request_ids_a - ] + # Requests started before one to be aborted. + tasks = [] + for request_id in request_ids_a: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Aborted. task_aborted = asyncio.create_task( - run_to_be_aborted(request_id_to_be_aborted)) - tasks_b = [ - asyncio.create_task(run_to_completion(request_id)) - for request_id in request_ids_b - ] + generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) + + # Requests started after one to be aborted. + for request_id in request_ids_b: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + # Actually abort. await asyncio.sleep(0.5) await client.abort(request_id_to_be_aborted) # Confirm that we got all the EXPECTED tokens from the requests. - for task in tasks_a: - assert (await task), "Expected this task to run to completion." - for task in tasks_b: - assert (await task), "Expected this task to run to completion." + for task in tasks: + count, request_id = await task + assert count == EXPECTED_TOKENS, ( + f"{request_id} generated only {count} tokens") # Cancel task (this will hang indefinitely if not). 
task_aborted.cancel() diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py index 4921f76745d6..37273f735e5f 100644 --- a/tests/mq_llm_engine/test_load.py +++ b/tests/mq_llm_engine/test_load.py @@ -6,19 +6,15 @@ import pytest -from tests.mq_llm_engine.utils import RemoteMQLLMEngine -from vllm import SamplingParams +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.multiprocessing.client import MQLLMEngineClient MODEL = "Qwen/Qwen2-0.5B-Instruct" NUM_EXPECTED_TOKENS = 10 -PROMPT = "Hello my name is Robert and I love" NUM_REQUESTS = 10000 # Scenarios to test for num generated token. -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True, - enable_chunked_prefill=True) +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) @pytest.fixture(scope="function") @@ -27,22 +23,6 @@ def tmp_socket(): yield f"ipc://{td}/{uuid.uuid4()}" -async def run_to_completion(client: MQLLMEngineClient, request_id: str): - - count = 0 - async for out in client.generate(inputs=PROMPT, - sampling_params=SamplingParams( - max_tokens=NUM_EXPECTED_TOKENS, - temperature=0), - request_id=request_id): - - count += 1 - await asyncio.sleep(0.) - - # Confirm we generated all the tokens we expected. - return count, request_id - - @pytest.mark.asyncio async def test_generation(tmp_socket): with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, @@ -56,7 +36,8 @@ async def test_generation(tmp_socket): tasks = [] for request_id in request_ids: tasks.append( - asyncio.create_task(run_to_completion(client, request_id))) + asyncio.create_task( + generate(client, request_id, NUM_EXPECTED_TOKENS))) # Confirm that we got all the EXPECTED tokens from the requests. failed_request_id = None diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 0b00c4209560..3eafb9807368 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -1,12 +1,30 @@ +import asyncio import multiprocessing -from typing import Callable +from typing import Callable, Tuple +from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import MQLLMEngine from vllm.usage.usage_lib import UsageContext +async def generate(client: MQLLMEngineClient, request_id: str, + num_tokens: int) -> Tuple[int, str]: + + count = 0 + async for _ in client.generate(request_id=request_id, + inputs="Hello my name is Robert and", + sampling_params=SamplingParams( + max_tokens=num_tokens, temperature=0)): + + count += 1 + await asyncio.sleep(0.) + + # Confirm we generated all the tokens we expected. + return count, request_id + + def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): # Make engine. 
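    # (Context: run_normal presumably serves as the subprocess target for
    # RemoteMQLLMEngine; it builds an MQLLMEngine from the given args and
    # serves requests over the ipc_path socket until shut down.)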
engine = MQLLMEngine.from_engine_args( From 117c02402fc1cddf96173662c6b7e5b9cccd3dab Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 14 Sep 2024 14:05:23 +0000 Subject: [PATCH 104/116] format --- tests/mq_llm_engine/test_load.py | 2 +- tests/mq_llm_engine/utils.py | 24 +++++++++++++++++------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py index 37273f735e5f..e5c1e6a824e7 100644 --- a/tests/mq_llm_engine/test_load.py +++ b/tests/mq_llm_engine/test_load.py @@ -24,7 +24,7 @@ def tmp_socket(): @pytest.mark.asyncio -async def test_generation(tmp_socket): +async def test_load(tmp_socket): with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, ipc_path=tmp_socket) as engine: diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 3eafb9807368..e27fd7792341 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -1,26 +1,36 @@ import asyncio import multiprocessing -from typing import Callable, Tuple +from typing import Callable, Tuple, Union from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext -async def generate(client: MQLLMEngineClient, request_id: str, - num_tokens: int) -> Tuple[int, str]: +async def generate( + client: MQLLMEngineClient, + request_id: str, + num_tokens: int, + return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]: + final_output = None count = 0 - async for _ in client.generate(request_id=request_id, - inputs="Hello my name is Robert and", - sampling_params=SamplingParams( - max_tokens=num_tokens, temperature=0)): + async for out in client.generate( + request_id=request_id, + inputs="Hello my name is Robert and", + sampling_params=SamplingParams(max_tokens=num_tokens, + temperature=0)): count += 1 + final_output = out await asyncio.sleep(0.) + if return_output: + return final_output + # Confirm we generated all the tokens we expected. return count, request_id From c059713f09c4fac26f429b1baf971d0c82c52ef0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 14 Sep 2024 14:15:22 +0000 Subject: [PATCH 105/116] remove debug print --- vllm/engine/multiprocessing/engine.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index e89e8dd2aabf..022860130ca6 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -193,10 +193,6 @@ def run_engine_loop(self): # Engine step. request_outputs = self.engine_step() - for request_output in request_outputs: - if request_output.request_id == "request-211": - print("\n\n\n") - print(request_output) # Send request outputs (if async, done in engine_step callback). 
if not self.use_async_sockets: From 1af3297c32adf461d9389b873229fceaca9cc9b1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 14 Sep 2024 14:21:07 +0000 Subject: [PATCH 106/116] removed stray --- vllm/entrypoints/openai/serving_chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b8bf64be6981..e4f1c834b910 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -220,7 +220,6 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - print("calling generate") result_generator = self.engine_client.generate( engine_inputs, sampling_params, From 97ae38d5d3af3c28ad98523839f62e925e1cc73a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 14 Sep 2024 14:24:28 +0000 Subject: [PATCH 107/116] updated --- examples/openai_chat_completion_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/openai_chat_completion_client.py b/examples/openai_chat_completion_client.py index bbada3891bd1..a7925f345709 100644 --- a/examples/openai_chat_completion_client.py +++ b/examples/openai_chat_completion_client.py @@ -2,7 +2,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" +openai_api_base = "http://localhost:8001/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") From d0fab110352dfe4434ec7374870e573bf7bb05c1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 14 Sep 2024 16:29:04 +0000 Subject: [PATCH 108/116] switch model to avoid OOM in TPU test --- tests/mq_llm_engine/test_abort.py | 2 +- tests/mq_llm_engine/test_error_handling.py | 2 +- tests/mq_llm_engine/test_load.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py index 68cebf6569dd..782b508a5714 100644 --- a/tests/mq_llm_engine/test_abort.py +++ b/tests/mq_llm_engine/test_abort.py @@ -9,7 +9,7 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs -MODEL = "Qwen/Qwen2-0.5B-Instruct" +MODEL = "google/gemma-1.1-2b-it" ENGINE_ARGS = AsyncEngineArgs(model=MODEL) RAISED_ERROR = KeyError RAISED_VALUE = "foo" diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 16e82794513c..ddd42494f20b 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser -MODEL = "Qwen/Qwen2-0.5B-Instruct" +MODEL = "google/gemma-1.1-2b-it" ENGINE_ARGS = AsyncEngineArgs(model=MODEL) RAISED_ERROR = KeyError RAISED_VALUE = "foo" diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py index e5c1e6a824e7..630c112d0f0c 100644 --- a/tests/mq_llm_engine/test_load.py +++ b/tests/mq_llm_engine/test_load.py @@ -9,7 +9,7 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs -MODEL = "Qwen/Qwen2-0.5B-Instruct" +MODEL = "google/gemma-1.1-2b-it" NUM_EXPECTED_TOKENS = 10 NUM_REQUESTS = 10000 From 1967f6a22041beba7bf33c67290c9d223e16180a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 16 Sep 2024 13:51:59 -0700 Subject: [PATCH 109/116] Adjust timeouts --- 
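Note: besides lengthening the test-side waits, this patch gives the engine process a short grace period to exit before force-killing it. A minimal, standalone sketch of that join-then-kill pattern (illustrative only; the helper names and the stand-in engine loop are assumptions, not vLLM APIs):

# Sketch of the graceful-then-forced shutdown applied to the engine process.
import multiprocessing
import time


def _engine_main() -> None:
    # Stand-in for the real engine busy loop.
    time.sleep(60)


def stop_engine(engine_process: multiprocessing.Process,
                grace_s: float = 4.0) -> None:
    # Give the engine a short grace period to exit cleanly...
    engine_process.join(grace_s)
    # ...then force-kill it if it is still running.
    if engine_process.exitcode is None:
        engine_process.kill()


if __name__ == "__main__":
    proc = multiprocessing.Process(target=_engine_main)
    proc.start()
    stop_engine(proc)
    proc.join()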
tests/entrypoints/openai/test_shutdown.py | 2 +- tests/utils.py | 2 +- vllm/engine/multiprocessing/engine.py | 4 +++- vllm/entrypoints/openai/api_server.py | 5 ++++- vllm/executor/multiproc_worker_utils.py | 4 ++++ 5 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 73ecb7400727..25ab91ef6933 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path): prompt="Hello, my name is") # Now the server should shut down - return_code = remote_server.proc.wait(timeout=3) + return_code = remote_server.proc.wait(timeout=8) assert return_code is not None diff --git a/tests/utils.py b/tests/utils.py index f6c2be17ebdc..81442cad78da 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -119,7 +119,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() try: - self.proc.wait(3) + self.proc.wait(8) except subprocess.TimeoutExpired: # force kill if needed self.proc.kill() diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 26f4901338c7..94bcdb7d9a40 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -204,7 +204,9 @@ def engine_step(self) -> List[RequestOutput]: try: return self.engine.step() - except Exception as e: + except SystemExit: + raise + except BaseException as e: self._set_errored(e) rpc_err = RPCError(request_id=None, is_engine_errored=True, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c2868d599b0f..b263384dd377 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -201,7 +201,10 @@ async def build_async_engine_client_from_engine_args( mp_engine_client.close() # Wait for engine process to join - engine_process.join() + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() # Lazy import for prometheus multiprocessing. 
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index aa2a16c04d08..5bef76b90d33 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -168,6 +168,8 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], self.tasks[task_id] = future try: self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise except BaseException as e: del self.tasks[task_id] raise ChildProcessError("worker died") from e @@ -222,6 +224,8 @@ def _run_worker_process( try: executor = getattr(worker, method) output = executor(*args, **kwargs) + except SystemExit: + raise except KeyboardInterrupt: break except BaseException as e: From a91132311abdde84f533d82e4bcf4cfcfff92379 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 17 Sep 2024 02:29:22 +0000 Subject: [PATCH 110/116] stahs --- vllm/engine/multiprocessing/__init__.py | 1 - vllm/engine/multiprocessing/client.py | 8 ------ vllm/engine/multiprocessing/engine.py | 37 ++++++++++--------------- vllm/envs.py | 2 +- 4 files changed, 15 insertions(+), 33 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index df9941b3eb12..ba5c6e15fc82 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -48,7 +48,6 @@ class RPCHealthRequest: class RPCStartupRequest(Enum): IS_SERVER_READY = 1 - CLIENT_IS_READY = 2 @dataclass diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index d53a8b9409e5..1bbfa03a3153 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -247,8 +247,6 @@ async def setup(self): self.health_loop = asyncio.create_task( self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) - # Notify MQLLMEngine client is ready to start sending requests. - await self._notify_ready(socket) def close(self): """Destroy the ZeroMQ Context.""" @@ -351,12 +349,6 @@ async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: error_message="Unable to start RPC Server", socket=socket) - async def _notify_ready(self, socket: Socket): - """Get the RPCServer that the RPCClient is ready""" - - await self._send_one_way_rpc_request( - request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) - async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 94bcdb7d9a40..52b4e2fb8a27 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -154,30 +154,21 @@ def run_startup_loop(self) -> None: """Startup loop for sending data from Engine -> Client.""" with self.make_data_socket() as socket: + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e - # Loop until the RPCClient has all the data it needs. 
- client_is_ready = False - while not client_is_ready: - response: Any - try: - identity, message = socket.recv_multipart(copy=False) - request: RPCStartupRequest = pickle.loads(message.buffer) - - # Handle the query from the Client. - if request == RPCStartupRequest.IS_SERVER_READY: - tracing_enabled = self.engine.is_tracing_enabled() - response = RPCStartupResponse( - tracing_enabled=tracing_enabled) - elif request == RPCStartupRequest.CLIENT_IS_READY: - response = VLLM_RPC_SUCCESS_STR - # Breakout of loop once client is ready. - client_is_ready = True - - except Exception as e: - response = e - - socket.send_multipart((identity, pickle.dumps(response)), - copy=False) + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) def run_engine_loop(self): """Core busy loop of the LLMEngine.""" diff --git a/vllm/envs.py b/vllm/envs.py index 740af8341d2b..fc9753c4c413 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,7 +57,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False - VLLM_RPC_TIMEOUT: int = 5000 + VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False From 95ff4f3e26e36dd34eaae0890902a56f990a97bd Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 17 Sep 2024 02:48:17 +0000 Subject: [PATCH 111/116] make timeout 10000 ms --- vllm/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index fc9753c4c413..a4bde548a23d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -393,7 +393,7 @@ def get_default_config_root(): # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations "VLLM_RPC_TIMEOUT": - lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "5000")), + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), # a list of plugin names to load, separated by commas. # if this is not set, it means all plugins will be loaded From 302868ec42619aa794c71f5c316f2a0e3de3e520 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 17 Sep 2024 02:58:41 +0000 Subject: [PATCH 112/116] format --- vllm/engine/multiprocessing/client.py | 1 - vllm/engine/multiprocessing/engine.py | 5 +++-- vllm/envs.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 1bbfa03a3153..18b620c74ddf 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -247,7 +247,6 @@ async def setup(self): self.health_loop = asyncio.create_task( self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) - def close(self): """Destroy the ZeroMQ Context.""" # Close all sockets and terminate the context. 
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 52b4e2fb8a27..70cd6e5cb600 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,7 +1,7 @@ import pickle import signal from contextlib import contextmanager -from typing import Any, Iterator, List, Optional, Union +from typing import Iterator, List, Optional, Union import cloudpickle import zmq @@ -154,6 +154,7 @@ def run_startup_loop(self) -> None: """Startup loop for sending data from Engine -> Client.""" with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] try: identity, message = socket.recv_multipart(copy=False) request: RPCStartupRequest = pickle.loads(message.buffer) @@ -168,7 +169,7 @@ def run_startup_loop(self) -> None: response = e socket.send_multipart((identity, pickle.dumps(response)), - copy=False) + copy=False) def run_engine_loop(self): """Core busy loop of the LLMEngine.""" diff --git a/vllm/envs.py b/vllm/envs.py index a4bde548a23d..262e56869e88 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,7 +57,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False - VLLM_RPC_TIMEOUT: int = 10000 # ms + VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False From add68ee763c54fc325beca4e49fb2a3ed664756b Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 17 Sep 2024 08:22:56 -0400 Subject: [PATCH 113/116] Update examples/openai_chat_completion_client.py Co-authored-by: Simon Mo --- examples/openai_chat_completion_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/openai_chat_completion_client.py b/examples/openai_chat_completion_client.py index a7925f345709..bbada3891bd1 100644 --- a/examples/openai_chat_completion_client.py +++ b/examples/openai_chat_completion_client.py @@ -2,7 +2,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. 
openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8001/v1" +openai_api_base = "http://localhost:8000/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") From 242b952275d88b1b73eee57448c5c9567e050080 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 17 Sep 2024 14:22:05 +0000 Subject: [PATCH 114/116] adjust RPC timeout on TPU --- tests/tpu/test_custom_dispatcher.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 7f3fb595321a..69ab67abdd12 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,5 +1,12 @@ +import os + from ..utils import compare_two_settings +# --enforce-eager on TPU causes graph compilation +# this times out default Health Check in the MQLLMEngine, +# so we set the timeout here to 30s +os.environ["VLLM_RPC_TIMEOUT"] = "30000" + def test_custom_dispatcher(): compare_two_settings("google/gemma-2b", From 3dafa26d830b1c339f84f93c61a25c0f89f7305a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 17 Sep 2024 20:24:21 +0000 Subject: [PATCH 115/116] add longer delay for check ehalth --- tests/mq_llm_engine/test_error_handling.py | 2 +- vllm/engine/multiprocessing/client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index ddd42494f20b..49cfc5aa04c3 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -110,7 +110,7 @@ async def test_failed_health_check(tmp_socket): assert client.is_running # Health probe should throw RAISED_ERROR. - await asyncio.sleep(10) + await asyncio.sleep(15.) with pytest.raises(RAISED_ERROR): await client.check_health() diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 18b620c74ddf..52e52ff392f8 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -365,6 +365,7 @@ async def check_health(self): Engine's health every N seconds and sets _errored_with if the engine is unhealthy. """ + print(self._errored_with) if self._errored_with is not None: raise self._errored_with From 836a9d2d54b90a170ff7ad8b4d005f87ffc334b8 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 17 Sep 2024 21:07:01 -0400 Subject: [PATCH 116/116] Update client.py --- vllm/engine/multiprocessing/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 52e52ff392f8..18b620c74ddf 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -365,7 +365,6 @@ async def check_health(self): Engine's health every N seconds and sets _errored_with if the engine is unhealthy. """ - print(self._errored_with) if self._errored_with is not None: raise self._errored_with