61 changes: 21 additions & 40 deletions vllm/v1/engine/async_llm.py
@@ -25,7 +25,7 @@
from vllm.utils import cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -145,25 +145,30 @@ async def add_request(
"""Add new request to the AsyncLLM."""

# 1) Create a new output queue for the request.
if self.output_processor.is_request_active(request_id):
raise ValueError(f"Request id {request_id} already running.")
queue: asyncio.Queue[RequestOutput] = asyncio.Queue()

# 2) Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
# 2) Fan out child requests (for n>1)
parent_req = ParentRequest.from_params(request_id, params)
n = params.n if isinstance(params, SamplingParams) else 1
for idx in range(n):
if parent_req is not None:
request_id, params = parent_req.get_child_info(idx)

# 3) Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, queue)
# 3) Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
Comment on lines +158 to +162
Contributor
Something I wondered - can we reuse the output of process_inputs over multiple child requests? Since I believe process_inputs operates solely on the prompt (or MM inputs or whatever.)

This could be tricky for multimodal so I don't know if it is in-scope for this PR.

Member Author
Yeah, I did wonder that too, but didn't dig into it. Seems quite orthogonal to this PR, though.

Member
Yes agree. Actually we can move this outside of the loop I think, and just update the request_id and sampling_params of the request inside the loop.

Doesn't need to hold up this PR but would be quite a simple change I think.
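
For illustration only, a rough sketch of that follow-up (not part of this PR): run process_inputs once, outside the fan-out loop, and reuse the result for every child, overwriting only the per-child id and params. The request_id / sampling_params attribute names on the processed request are assumptions and may not match the real EngineCoreRequest fields.

from copy import copy

# Process the (prompt-only) inputs a single time, outside the fan-out loop.
base_request = self.processor.process_inputs(request_id, prompt, params,
                                             arrival_time, lora_request,
                                             trace_headers,
                                             prompt_adapter_request,
                                             priority)

for idx in range(n):
    request = base_request
    if parent_req is not None:
        child_id, child_params = parent_req.get_child_info(idx)
        request = copy(base_request)            # shallow copy; prompt/MM data stays shared
        request.request_id = child_id           # assumed field name
        request.sampling_params = child_params  # assumed field name

    self.output_processor.add_request(request, parent_req, idx, queue)
    await self.engine_core.add_request_async(request)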


# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)
# 4) Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, parent_req, idx, queue)

if self.log_requests:
logger.info("Added request %s.", request_id)
# 5) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)

if self.log_requests:
logger.info("Added request %s.", request_id)

return queue

@@ -172,7 +177,7 @@ async def add_request(
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def _generate(
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
@@ -243,30 +248,6 @@ async def _generate(
await self.abort(request_id)
raise

def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
kwargs = dict(prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority)
if sampling_params.n is None or sampling_params.n == 1:
return self._generate(**kwargs)
else:
# Special handling for parallel sampling requests
return generate_parallel_sampling_async(generate=self._generate,
**kwargs)

async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""

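As a usage-level illustration (not taken from the PR itself): with the generate/_generate split and the generate_parallel_sampling_async wrapper gone, a single generate() call now covers n>1; the fan-out into child requests happens inside add_request, and the OutputProcessor re-aggregates the children under the original request. The demo coroutine and prompt below are made up for the example.

from vllm import SamplingParams

async def demo(async_llm):
    # n=3 no longer takes a separate parallel-sampling code path.
    params = SamplingParams(n=3, max_tokens=32)
    async for output in async_llm.generate(prompt="Hello, my name is",
                                           sampling_params=params,
                                           request_id="demo-0"):
        # Aggregated outputs arrive on a single stream for the original request.
        print(output.request_id, [c.text for c in output.outputs])

# Run with: asyncio.run(demo(engine)) once an AsyncLLM instance exists.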
74 changes: 22 additions & 52 deletions vllm/v1/engine/llm_engine.py
@@ -22,7 +22,7 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor

@@ -50,9 +50,6 @@ def __init__(
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config

# Bookkeeping for parallel sampling requests
self.parallel_manager = SyncParallelSamplingManager()

# important: init dp group before init the engine_core
self.parallel_config = vllm_config.parallel_config
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
@@ -120,8 +117,7 @@ def from_engine_args(
multiprocess_mode=enable_multiprocessing)

def get_num_unfinished_requests(self) -> int:
return self.parallel_manager.get_num_unfinished_requests(
self.output_processor.get_num_unfinished_requests())
return self.output_processor.get_num_unfinished_requests()

def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests()
@@ -157,48 +153,25 @@ def add_request(
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
"""Add request."""
kwargs = dict(request_id=request_id,
prompt=prompt,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority)
# Handle parallel sampling requests differently.
if params is None or isinstance(params,
PoolingParams) or params.n == 1:
self._add_request(**kwargs)
else:
# Special handling for parallel sampling requests
self.parallel_manager.add_request_parallel_sampling(
add_request=self._add_request, **kwargs)

def _add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
"""Add request, `n=1`"""
# 1) Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)

# 2) Make a new RequestState and queue.
self.output_processor.add_request(request)

# 3) Add the request to EngineCore.
self.engine_core.add_request(request)
# 1) Fan out child requests (for n>1)
parent_req = ParentRequest.from_params(request_id, params)
n = params.n if isinstance(params, SamplingParams) else 1
for idx in range(n):
if parent_req is not None:
request_id, params = parent_req.get_child_info(idx)

# 2) Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)

# 3) Make a new RequestState and queue.
self.output_processor.add_request(request, parent_req, idx)

# 4) Add the request to EngineCore.
self.engine_core.add_request(request)

def step(self) -> list[RequestOutput]:

@@ -217,10 +190,7 @@ def step(self) -> list[RequestOutput]:
# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

request_outputs = processed_outputs.request_outputs

# 4) Process unfinished parallel sampling requests
return self.parallel_manager.step(request_outputs)
return processed_outputs.request_outputs

def get_model_config(self):
return self.model_config
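
For readers unfamiliar with the new helper, here is an illustrative-only sketch of the ParentRequest contract that both add_request paths above rely on. The real class in vllm/v1/engine/parallel_sampling.py differs in its details (child-id format, params handling, output aggregation); the class name and method bodies below are assumptions.

from copy import deepcopy
from typing import Optional

from vllm import SamplingParams


class ParentRequestSketch:
    """Tracks an n>1 sampling request that fans out into n child requests."""

    def __init__(self, request_id: str, params: SamplingParams):
        self.request_id = request_id
        self.params = params

    @classmethod
    def from_params(cls, request_id: str, params) -> Optional["ParentRequestSketch"]:
        # Only sampling requests with n > 1 need a parent; pooling requests
        # and n == 1 sampling requests return None and are handled as-is.
        if isinstance(params, SamplingParams) and params.n > 1:
            return cls(request_id, params)
        return None

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        # Each child gets a derived request id and a params copy with n=1,
        # so the engine core sees n independent single-sample requests.
        child_params = deepcopy(self.params)
        child_params.n = 1
        return f"{index}_{self.request_id}", child_params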