diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 57ea3cd23f38..d521382e97b0 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -77,6 +77,7 @@ def run_vllm(
     gpu_memory_utilization: float = 0.9,
 ) -> float:
     from vllm import LLM, SamplingParams
+    print("Ready to initialize vLLM")
     llm = LLM(model=model,
               tokenizer=tokenizer,
               quantization=quantization,
@@ -129,6 +130,7 @@ def run_vllm(
 
     start = time.perf_counter()
     # FIXME(woosuk): Do not use internal method.
+    # _run_engine may be the source of the problem.
     llm._run_engine(use_tqdm=True)
     end = time.perf_counter()
     return end - start
diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index 6460164aa34b..fc3037dbec49 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -11,10 +11,12 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
 
 # Create an LLM.
-llm = LLM(model="YOUR_MODEL", device="xpu", enforce_eager=True, dtype="float16", gpu_memory_utilization=0.80, trust_remote_code=True)
+llm = LLM(model="/llm/models/Llama-2-7b-chat-hf", device="xpu", enforce_eager=True, dtype="float16", gpu_memory_utilization=0.80, trust_remote_code=True, tensor_parallel_size=2)
 outputs = llm.generate(prompts, sampling_params)
 # Print the outputs.
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+llm.clean_up()
\ No newline at end of file
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 10758aaa4310..a76c6ad5620a 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -126,11 +126,15 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
 
+        # TODO: revisit this executor selection later.
         # Initialize the cluster and specify the executor class.
-        if parallel_config.worker_use_ray:
-            initialize_ray_cluster(parallel_config)
-            from vllm.executor.ray_gpu_executor import RayGPUExecutor
-            executor_class = RayGPUExecutor
+        # if parallel_config.worker_use_ray:
+        #     initialize_ray_cluster(parallel_config)
+        #     from vllm.executor.ray_gpu_executor import RayGPUExecutor
+        #     executor_class = RayGPUExecutor
+        if parallel_config.world_size > 1:
+            from vllm.executor.single_node_gpu_executor import SingleNodeXpuExecutor
+            executor_class = SingleNodeXpuExecutor
         else:
             assert parallel_config.world_size == 1, (
                 "Ray is required if parallel_config.world_size > 1.")
diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py
new file mode 100644
index 000000000000..bb9b173a40f1
--- /dev/null
+++ b/vllm/engine/local_worker_utils.py
@@ -0,0 +1,256 @@
+import asyncio
+import multiprocessing
+import os
+import sys
+import threading
+import traceback
+import uuid
+from dataclasses import dataclass
+from io import TextIOBase
+from multiprocessing.connection import wait
+from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+T = TypeVar('T')
+
+_TERMINATE = "TERMINATE"  # sentinel
+
+# ANSI color codes
+CYAN = '\033[1;36m'
+RESET = '\033[0;0m'
+
+# Use dedicated multiprocess context for workers.
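+# "fork" starts faster because workers inherit the parent's state, while
+# "spawn" re-imports modules in a clean interpreter; once the XPU runtime has
+# been initialized in the parent process, "spawn" is likely the safer choice.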
+# Both spawn and fork work
+mp_method = os.getenv("MULTIPROC_METHOD", "fork")
+mp = multiprocessing.get_context(mp_method)
+
+
+@dataclass
+class Result(Generic[T]):
+    """Result of task dispatched to worker"""
+
+    task_id: Optional[uuid.UUID] = None
+    value: Optional[T] = None
+    exception: Optional[BaseException] = None
+
+
+class ResultFuture(threading.Event, Generic[T]):
+    """Synchronous future for non-async case"""
+
+    def __init__(self):
+        super().__init__()
+        self.result: Optional[Result[T]] = None
+
+    def set_result(self, result: Result[T]):
+        self.result = result
+        self.set()
+
+    def get(self) -> T:
+        self.wait()
+        if self.result.exception is not None:
+            raise self.result.exception
+        return self.result.value
+
+
+def _set_future_result(future: Union[ResultFuture, asyncio.Future],
+                       result: Result):
+    if isinstance(future, ResultFuture):
+        future.set_result(result)
+        return
+    loop = future.get_loop()
+    if result.exception is not None:
+        loop.call_soon_threadsafe(future.set_exception, result.exception)
+    else:
+        loop.call_soon_threadsafe(future.set_result, result.value)
+
+
+class ResultHandler(threading.Thread):
+    """Handle results from all workers (in background thread)"""
+
+    def __init__(self) -> None:
+        super().__init__(daemon=True)
+        self.result_queue = mp.Queue()
+        self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
+
+    def run(self):
+        for result in iter(self.result_queue.get, _TERMINATE):
+            future = self.tasks.pop(result.task_id)
+            _set_future_result(future, result)
+        # Ensure that all waiters will receive an exception
+        for future in self.tasks.values():
+            _set_future_result(
+                future, Result(exception=ChildProcessError("worker died")))
+
+    def close(self):
+        self.result_queue.put(_TERMINATE)
+
+
+class WorkerMonitor(threading.Thread):
+    """Monitor worker status (in background thread)"""
+
+    def __init__(self, workers: List['LocalWorkerVllm'],
+                 result_handler: ResultHandler):
+        super().__init__(daemon=True)
+        self.workers = workers
+        self.result_handler = result_handler
+        self._close = False
+
+    def run(self) -> None:
+        # Blocks until any worker exits
+        dead_sentinels = wait([p.sentinel for p in self.workers])
+        if not self._close:
+            self._close = True
+
+            # Kill / cleanup all workers
+            for worker in self.workers:
+                if worker.sentinel in dead_sentinels:
+                    worker.join(1)
+                if worker.exitcode is not None and worker.exitcode != 0:
+                    logger.error(
+                        f"Worker {worker.name} pid {worker.pid} died, "
+                        f"exit code: {worker.exitcode}")
+            # Cleanup any remaining workers
+            logger.info("Killing local vLLM worker processes")
+            for worker in self.workers:
+                worker.kill_worker()
+                worker.clean()
+            # Must be done after worker task queues are all closed
+            self.result_handler.close()
+
+        for worker in self.workers:
+            worker.join(2)
+
+    def close(self):
+        if self._close:
+            return
+        self._close = True
+        logger.info("Terminating local vLLM worker processes")
+        for worker in self.workers:
+            worker.terminate_worker()
+        # Must be done after worker task queues are all closed
+        self.result_handler.close()
+
+
+class LocalWorkerVllm(mp.Process):
+    """Local process wrapper for vllm.worker.Worker
+    for handling single-node multi-GPU tensor parallelism."""
+
+    def __init__(self, result_handler: ResultHandler,
+                 worker_factory: Callable[[], Any]) -> None:
+        import intel_extension_for_pytorch as ipex
+        super().__init__(daemon=True)
+        self._task_queue = mp.Queue()
+        self.result_queue = result_handler.result_queue
+        self.tasks = result_handler.tasks
+        self.worker_factory = worker_factory
+        self.worker = None
+
+    def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
+                      method: str, args, kwargs):
+        task_id = uuid.uuid4()
+        self.tasks[task_id] = future
+        try:
+            self._task_queue.put((task_id, method, args, kwargs))
+        except BaseException as e:
+            del self.tasks[task_id]
+            raise ChildProcessError("worker died") from e
+
+    def execute_method(self, method: str, *args, **kwargs):
+        future = ResultFuture()
+        self._enqueue_task(future, method, args, kwargs)
+        return future
+
+    async def execute_method_async(self, method: str, *args, **kwargs):
+        future = asyncio.get_running_loop().create_future()
+        self._enqueue_task(future, method, args, kwargs)
+        return await future
+
+    def terminate_worker(self):
+        try:
+            self._task_queue.put(_TERMINATE)
+        except ValueError:
+            self.kill()
+        self._task_queue.close()
+
+    def kill_worker(self):
+        self._task_queue.close()
+        self.kill()
+
+    def run(self) -> None:
+        # Add process-specific prefix to stdout and stderr
+        process_name = mp.current_process().name
+        pid = os.getpid()
+        _add_prefix(sys.stdout, process_name, pid)
+        _add_prefix(sys.stderr, process_name, pid)
+
+        del self.tasks  # Not used in forked process
+        self.worker = self.worker_factory()
+        del self.worker_factory
+
+        # Accept tasks from the engine in task_queue
+        # and return task output in result_queue
+        logger.info("Worker ready; awaiting tasks")
+        try:
+            for items in iter(self._task_queue.get, _TERMINATE):
+                output = None
+                exception = None
+                task_id, method, args, kwargs = items
+                try:
+                    executor = getattr(self.worker, method)
+                    output = executor(*args, **kwargs)
+                except BaseException as e:
+                    tb = traceback.format_exc()
+                    logger.error(
+                        f"Exception in worker {mp.current_process().name} "
+                        f"while processing method {method}: {e}, {tb}")
+                    exception = e
+                    self.clean()
+                self.result_queue.put(
+                    Result(task_id=task_id, value=output, exception=exception))
+        except KeyboardInterrupt:
+            self.clean()
+            pass
+        except Exception:
+            self.clean()
+            logger.exception("Worker failed")
+
+        logger.info("Worker exiting")
+
+    def clean(self):
+        print("Performing self-cleanup for worker")
+        import torch
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+        del self.worker.model_runner.model
+        import gc
+        gc.collect()
+
+
+def _add_prefix(file: TextIOBase, worker_name: str, pid: int) -> None:
+    """Prepend output with process-specific prefix"""
+
+    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
+    file_write = file.write
+
+    def write_with_prefix(s: str):
+        if not s:
+            return
+        if file.start_new_line:
+            file_write(prefix)
+        idx = 0
+        while (next_idx := s.find('\n', idx)) != -1:
+            next_idx += 1
+            file_write(s[idx:next_idx])
+            if next_idx == len(s):
+                file.start_new_line = True
+                return
+            file_write(prefix)
+            idx = next_idx
+        file_write(s[idx:])
+        file.start_new_line = False
+
+    file.start_new_line = True
+    file.write = write_with_prefix
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 1f463bdaaedc..313244fdf44f 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -210,3 +210,11 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # its previous requests.
         outputs = sorted(outputs, key=lambda x: int(x.request_id))
         return outputs
+
+    def clean_up(self):
+        if self.llm_engine.parallel_config.tensor_parallel_size == 1:
+            pass
+        else:
+            # Release the KV cache and model weights held by the workers.
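+            # The executor broadcasts clean() to every worker process, so
+            # device memory can be released promptly instead of only when
+            # the worker processes exit.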
+            print("Cleaning up tensor parallel workers")
+            self.llm_engine.model_executor.clean()
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 30717e8a8735..28924a9cd386 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -54,6 +54,9 @@ def check_health(self) -> None:
         exception."""
         raise NotImplementedError
 
+    def clean(self) -> None:
+        raise NotImplementedError
+
 
 class ExecutorAsyncBase(ExecutorBase):
 
diff --git a/vllm/executor/single_node_gpu_executor.py b/vllm/executor/single_node_gpu_executor.py
new file mode 100644
index 000000000000..c8300ebcce45
--- /dev/null
+++ b/vllm/executor/single_node_gpu_executor.py
@@ -0,0 +1,285 @@
+from abc import abstractmethod
+from typing import Any, Dict, Optional, Set, Tuple
+
+from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig)
+from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from vllm.executor.utils import check_block_size_valid
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.sequence import SamplerOutput
+from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler,
+                                            WorkerMonitor)
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+import torch
+
+from functools import partial
+
+logger = init_logger(__name__)
+import os
+
+
+def get_xpu_device_type(x):
+    if x.device.type != "xpu":
+        return x.device.type
+    name = torch.xpu.get_device_name(x.device.index)
+    if name.startswith("Intel(R) Arc(TM) A"):
+        return "arc"
+    elif name.startswith("Intel(R) Arc(TM)"):
+        return "mtl"
+    elif name.startswith("Intel(R) Data Center GPU Flex"):
+        return "flex"
+    elif name.startswith("Intel(R) Data Center GPU Max"):
+        return "pvc"
+    else:
+        return "others"
+
+
+"""
+To create a new worker, we probably need to do the following:
+1. Create the executor; the executor should manage the workers.
+2. The executor is used to execute the model on the devices.
+Need to implement the following methods:
+add_lora method
+remove_lora method
+list_loras method
+check_health method
+"""
+
+
+def _create_worker(*args, **kwargs):
+    # Import within worker process to avoid CUDA init issues
+    from vllm.worker.worker import Worker
+    return Worker(*args, **kwargs)
+
+
+class SingleNodeXpuExecutor(ExecutorBase):
+    """Python multiprocessing-based multi-GPU executor"""
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        print("Initializing SingleNodeXpuExecutor")
+        # TODO: revisit this initialization later.
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        self._init_executor()
+
+    def _init_executor(self) -> None:
+        # Create the parallel GPU workers.
+        self._init_workers()
+
+        # Profile the memory usage and initialize the cache.
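+        # _init_cache() runs profile_num_available_blocks on the driver and on
+        # every worker process, then sizes the KV cache from the minimum
+        # across all of them.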
+        self._init_cache()
+
+    # TODO: implement multi-card self-selection to select cards
+    def _init_workers(self):
+        # TODO: fix the CUDA issues
+        world_size = self.parallel_config.tensor_parallel_size
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        # if "CUDA_VISIBLE_DEVICES" not in os.environ:
+        #     set_cuda_visible_devices(range(world_size))
+
+        # TODO: enable device count using torch.xpu api
+        # from torch.cuda import device_count
+        # assert world_size <= device_count(), (
+        #     "please set tensor_parallel_size to less than max local gpu count")
+
+        # Get a distributed_init_method
+        # FIXME: we probably want to do something with the proxy?
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+
+        if world_size == 1:
+            self.workers = []
+        else:
+            result_handler = ResultHandler()
+            self.workers = [
+                LocalWorkerVllm(
+                    result_handler,
+                    partial(
+                        _create_worker,
+                        self.model_config,
+                        self.parallel_config,
+                        self.scheduler_config,
+                        self.device_config,
+                        local_rank=rank,
+                        rank=rank,
+                        distributed_init_method=distributed_init_method,
+                        lora_config=self.lora_config,
+                        kv_cache_dtype=self.cache_config.cache_dtype,
+                    )) for rank in range(1, world_size)
+            ]
+
+            for worker in self.workers:
+                worker.start()
+
+            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
+            result_handler.start()
+            self.worker_monitor.start()
+
+        self._init_driver_worker_and_model(0, 0, distributed_init_method)
+
+    def _init_driver_worker_and_model(self, rank: int, local_rank: int,
+                                      distributed_init_method: str):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker
+
+        # Initialize the driver worker with the Worker class.
+        self.driver_worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=True,
+        )
+
+        # self._run_workers("init_device")
+        self._run_workers("init_model",
+                          cupy_port=get_open_port()
+                          if not self.model_config.enforce_eager else None)
+        self._run_workers("load_model",
+                          max_concurrent_workers=self.parallel_config.
+                          max_parallel_loading_workers)
+
+    def _init_cache(self) -> None:
+        # TODO: fix this _init_cache
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
+        from class :class:`~vllm.worker.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine will initialize the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
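+        # _run_workers returns the driver's result first, followed by one
+        # (num_gpu_blocks, num_cpu_blocks) tuple per worker process.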
+        num_blocks = self._run_workers(
+            "profile_num_available_blocks",
+            block_size=self.cache_config.block_size,
+            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
+        )
+
+        # Since we use a shared centralized controller, we take the minimum
+        # number of blocks across all workers to make sure all the memory
+        # operators can be applied to all workers.
+        num_gpu_blocks = min(b[0] for b in num_blocks)
+        num_cpu_blocks = min(b[1] for b in num_blocks)
+
+        # if self.cache_config.forced_num_gpu_blocks is not None:
+        #     forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
+        #     logger.info(f"Replacing profiled {num_gpu_blocks=} with "
+        #                 f"{forced_num_gpu_blocks=}")
+        #     num_gpu_blocks = forced_num_gpu_blocks
+
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
+
+        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
+                               self.model_config.max_model_len)
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        # Initialize the cache.
+        self._run_workers("init_cache_engine", cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self._run_workers("warm_up_model")
+
+    def execute_model(self, *args, **kwargs) -> SamplerOutput:
+        all_outputs = self._run_workers("execute_model",
+                                        driver_args=args,
+                                        driver_kwargs=kwargs)
+
+        # Only the driver worker returns the sampling results.
+        return all_outputs[0]
+
+    def check_health(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        if not self.worker_monitor.is_alive():
+            raise RuntimeError("Worker processes are not running")
+
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[Tuple[Any, ...]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        max_concurrent_workers: Optional[int] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+
+        if max_concurrent_workers:
+            raise NotImplementedError(
+                "max_concurrent_workers is not supported yet.")
+
+        # Start the workers first.
+        worker_outputs = [
+            worker.execute_method(method, *args, **kwargs)
+            for worker in self.workers
+        ]
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Start the driver worker after all the local workers.
+        driver_worker_method = getattr(self.driver_worker, method)
+        driver_worker_output = driver_worker_method(*driver_args,
+                                                    **driver_kwargs)
+
+        # Get the results of the workers.
+        return [driver_worker_output
+                ] + [output.get() for output in worker_outputs]
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "add_lora",
+            lora_request=lora_request,
+        )
+
+    def list_loras(self) -> Set[int]:
+        return self._run_workers("list_loras")
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
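+        # remove_lora is broadcast to every worker process so the adapter is
+        # dropped from all tensor-parallel ranks, not just the driver.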
+        return self._run_workers(
+            "remove_lora",
+            lora_id=lora_id,
+        )
+
+    def clean(self):
+        return self._run_workers("clean")
\ No newline at end of file
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 5c268d63c27e..67852953d6f1 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -91,7 +91,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
         from ipex_llm import optimize_model
         # print(model)
         # input("pause")
-        optimize_model(model)
+        optimize_model(model, low_bit="fp16")
         # print("optimized ***********************************")
         # print(model)
         model = model.to(device=device_config.device, dtype=model_config.dtype)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 76c848e4b96b..dd1f2f7b3f16 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -84,6 +84,7 @@ def __init__(
             self.model_config.enforce_eager = True
 
     def load_model(self) -> None:
+        # TODO: add later
         with measure_device_memory() as m:
             self.model = get_model(self.model_config,
                                    self.device_config,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 183642e96f23..c30d5f4b236e 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -262,6 +262,15 @@ def get_cache_block_size_bytes(self, block_size: int,
                                                 self.model_config,
                                                 self.parallel_config)
 
+    def clean(self):
+        print("Performing cleanup for the main process")
+        del self.model_runner.model
+        import torch
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+        import gc
+        gc.collect()
+
 
 def init_distributed_environment(
     parallel_config: ParallelConfig,