Merged
Changes from 6 commits
2 changes: 2 additions & 0 deletions tests/async_engine/test_async_llm_engine.py
@@ -91,4 +91,6 @@ async def test_new_requests_event():
assert engine.engine.step_calls == old_step_calls + 1

engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
assert engine.get_model_config() is not None
assert engine.get_tokenizer() is not None
assert engine.get_decoding_config() is not None
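The new assertions exercise the config accessors on an engine created with engine_use_ray=True. As a rough illustration (engine arguments assumed, not part of this PR), the same accessors can be awaited on a real AsyncLLMEngine regardless of whether the engine runs as a Ray actor:

# Minimal sketch, not part of this PR: each accessor dispatches to a Ray
# remote call when engine_use_ray=True and to a plain method call otherwise,
# so callers stay agnostic of the deployment mode.
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def show_configs() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", engine_use_ray=True))
    model_config = await engine.get_model_config()
    decoding_config = await engine.get_decoding_config()
    print(model_config.model, decoding_config.guided_decoding_backend)


asyncio.run(show_configs())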
157 changes: 157 additions & 0 deletions tests/entrypoints/test_openapi_server_ray.py
@@ -0,0 +1,157 @@
# imports for guided decoding tests
import os
import subprocess
import sys
import time

import openai # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests

MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"


@ray.remote(num_gpus=1)
class ServerRunner:

def __init__(self, args):
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
self._wait_for_server()

def ready(self):
return True

def _wait_for_server(self):
# run health check
start = time.time()
while True:
try:
if requests.get(
"http://localhost:8000/health").status_code == 200:
break
except Exception as err:
if self.proc.poll() is not None:
raise RuntimeError("Server exited unexpectedly.") from err

time.sleep(0.5)
if time.time() - start > MAX_SERVER_START_WAIT_S:
raise RuntimeError(
"Server failed to start in time.") from err

def __del__(self):
if hasattr(self, "proc"):
self.proc.terminate()


@pytest.fixture(scope="session")
def server():
ray.init()
server_runner = ServerRunner.remote([
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"2048",
"--enforce-eager",
"--engine-use-ray"
])
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()


@pytest.fixture(scope="session")
def client():
client = openai.AsyncOpenAI(
base_url="http://localhost:8000/v1",
api_key="token-abc123",
)
yield client


@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)


@pytest.mark.asyncio
async def test_single_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)

# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5


@pytest.mark.asyncio
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

# test single completion
chat_completion = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert chat_completion.id is not None
assert chat_completion.choices is not None and len(
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.top_logprobs is not None
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})

# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
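The file's header comment mentions guided decoding; a natural follow-up check (a sketch only, not part of this diff, and it assumes the server accepts a guided_choice field via extra_body) would exercise the code path that now fetches the decoding config through the Ray-wrapped engine:

@pytest.mark.asyncio
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
    # Hypothetical test: guided decoding is the feature that consults
    # DecodingConfig.guided_decoding_backend on the server side.
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="The capital of France is",
        max_tokens=5,
        temperature=0.0,
        extra_body={"guided_choice": ["Paris", "London"]})
    assert completion.choices[0].text.strip() in ["Paris", "London"]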
10 changes: 9 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -7,7 +7,7 @@

from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.config import DecodingConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
@@ -697,6 +697,14 @@ async def get_model_config(self) -> ModelConfig:
else:
return self.engine.get_model_config()

async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_decoding_config.remote( # type: ignore
)
else:
return self.engine.get_decoding_config()

async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote() # type: ignore
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -450,6 +450,10 @@ def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config

def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config

def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
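For reference, a rough sketch of what the new accessor returns; the default backend value is assumed here, not taken from this diff:

from vllm.config import DecodingConfig

config = DecodingConfig()
# guided_decoding_backend is the engine-wide default used when a request does
# not name a backend itself; "outlines" is assumed to be the default.
print(config.guided_decoding_backend)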
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -68,7 +68,7 @@ async def create_chat_completion(
request, prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -88,7 +88,7 @@ async def create_completion(self, request: CompletionRequest,
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
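Both endpoints now share the same fallback rule: a backend named on the request wins, otherwise the engine-wide default from the decoding config is used. A self-contained sketch with stand-in classes (not vLLM types) makes the precedence explicit:

from dataclasses import dataclass
from typing import Optional


@dataclass
class StubDecodingConfig:
    guided_decoding_backend: str = "outlines"


@dataclass
class StubRequest:
    guided_decoding_backend: Optional[str] = None


def resolve_backend(request: StubRequest, config: StubDecodingConfig) -> str:
    # The backend named on the request wins; otherwise fall back to the
    # engine-wide default fetched via get_decoding_config().
    return request.guided_decoding_backend or config.guided_decoding_backend


assert resolve_backend(StubRequest(), StubDecodingConfig()) == "outlines"
assert resolve_backend(StubRequest("custom-backend"),
                       StubDecodingConfig()) == "custom-backend"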