[V1] Refactor parallel sampling support #13774
```diff
@@ -25,7 +25,7 @@
 from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
```
```diff
@@ -145,25 +145,30 @@ async def add_request(
         """Add new request to the AsyncLLM."""

         # 1) Create a new output queue for the request.
         if self.output_processor.is_request_active(request_id):
             raise ValueError(f"Request id {request_id} already running.")
         queue: asyncio.Queue[RequestOutput] = asyncio.Queue()

-        # 2) Convert Input --> Request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        # 2) Fan out child requests (for n>1)
+        parent_req = ParentRequest.from_params(request_id, params)
+        n = params.n if isinstance(params, SamplingParams) else 1
+        for idx in range(n):
+            if parent_req is not None:
+                request_id, params = parent_req.get_child_info(idx)

-        # 3) Add the request to OutputProcessor (this process).
-        self.output_processor.add_request(request, queue)
+            # 3) Convert Input --> Request.
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    trace_headers,
+                                                    prompt_adapter_request,
+                                                    priority)

-        # 4) Add the EngineCoreRequest to EngineCore (separate process).
-        await self.engine_core.add_request_async(request)
+            # 4) Add the request to OutputProcessor (this process).
+            self.output_processor.add_request(request, parent_req, idx, queue)

-        if self.log_requests:
-            logger.info("Added request %s.", request_id)
+            # 5) Add the EngineCoreRequest to EngineCore (separate process).
+            await self.engine_core.add_request_async(request)
+
+            if self.log_requests:
+                logger.info("Added request %s.", request_id)

         return queue
```
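The hunk above shows only the call sites, not `ParentRequest` itself. As a reading aid, here is a minimal, hypothetical sketch of the contract the loop relies on: `from_params` returns `None` for plain n=1 requests, and `get_child_info(idx)` hands back a per-child request id plus child sampling params. Everything below is an illustrative stand-in, not the actual implementation in `vllm/v1/engine/parallel_sampling.py`.

```python
# Illustrative sketch only -- not the real ParentRequest from
# vllm/v1/engine/parallel_sampling.py. It mimics the contract used in
# add_request(): from_params() -> Optional[parent], and
# get_child_info(idx) -> (child_request_id, child_params).
from copy import copy
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class DemoSamplingParams:
    """Stand-in for vllm.SamplingParams with just the field we need."""
    n: int = 1


class DemoParentRequest:
    """Tracks one user-facing request that fans out into n child requests."""

    def __init__(self, request_id: str, params: DemoSamplingParams) -> None:
        self.request_id = request_id
        self.params = params

    @classmethod
    def from_params(
            cls, request_id: str,
            params: DemoSamplingParams) -> Optional["DemoParentRequest"]:
        # Ordinary n == 1 requests need no parent wrapper at all.
        if params.n <= 1:
            return None
        return cls(request_id, params)

    def get_child_info(self, index: int) -> Tuple[str, DemoSamplingParams]:
        # Each child gets a unique id derived from the parent id and
        # samples a single completion (n=1).
        child_params = copy(self.params)
        child_params.n = 1
        return f"{index}_{self.request_id}", child_params


if __name__ == "__main__":
    parent = DemoParentRequest.from_params("req-42", DemoSamplingParams(n=3))
    for idx in range(3):
        child_id, child_params = parent.get_child_info(idx)
        print(child_id, child_params.n)  # "0_req-42 1", "1_req-42 1", ...
```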
```diff
@@ -172,7 +177,7 @@ async def add_request(
     # requests we don't need to send multiple messages to core proc,
     # and so we don't need multiple streams which then get
     # re-multiplexed in the API server anyhow.
-    async def _generate(
+    async def generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
```
```diff
@@ -243,30 +248,6 @@ async def _generate(
             await self.abort(request_id)
             raise

-    def generate(
-        self,
-        prompt: PromptType,
-        sampling_params: SamplingParams,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        kwargs = dict(prompt=prompt,
-                      sampling_params=sampling_params,
-                      request_id=request_id,
-                      lora_request=lora_request,
-                      trace_headers=trace_headers,
-                      prompt_adapter_request=prompt_adapter_request,
-                      priority=priority)
-        if sampling_params.n is None or sampling_params.n == 1:
-            return self._generate(**kwargs)
-        else:
-            # Special handling for parallel sampling requests
-            return generate_parallel_sampling_async(generate=self._generate,
-                                                    **kwargs)
-
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
```
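With the dispatching wrapper removed, there is a single `generate()` entry point for both n==1 and n>1. Below is a hedged caller-side sketch; it assumes an already-constructed `AsyncLLM` instance named `engine` (construction not shown), and the keyword names come from the signature visible in the removed wrapper above. One request id is submitted, the engine fans it out into child requests internally, and all n completions come back on the same output stream.

```python
# Hypothetical usage sketch; `engine` is assumed to be an initialized
# vllm.v1.engine.async_llm.AsyncLLM instance (construction not shown).
from vllm import SamplingParams


async def sample_three(engine) -> None:
    params = SamplingParams(n=3, max_tokens=32)
    # A single request id covers all three samples; the fan-out into child
    # requests happens inside add_request(), not in the caller.
    async for output in engine.generate(prompt="The capital of France is",
                                        sampling_params=params,
                                        request_id="demo-parallel-req"):
        if output.finished:
            # With n=3 the final RequestOutput carries three completions.
            for completion in output.outputs:
                print(completion.text)
```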