Merged

Changes from all commits (40 commits):
7d25c1a  add plugin via entrypoints (youkaichao, Aug 11, 2024)
dc33035  add (youkaichao, Aug 11, 2024)
239e96e  use (youkaichao, Aug 11, 2024)
be7df65  update (youkaichao, Aug 12, 2024)
0632fb7  add model (youkaichao, Aug 12, 2024)
a2e1280  update (youkaichao, Aug 12, 2024)
a936ebc  use name (youkaichao, Aug 12, 2024)
5266296  update (youkaichao, Aug 12, 2024)
efd04f1  update tests (youkaichao, Aug 12, 2024)
06dddb7  update (youkaichao, Aug 12, 2024)
01248da  update (youkaichao, Aug 12, 2024)
3b2adc8  update (youkaichao, Aug 12, 2024)
93cb2cf  update (youkaichao, Aug 12, 2024)
b66400a  update (youkaichao, Aug 12, 2024)
a20a5bf  update server oot test (youkaichao, Aug 12, 2024)
8b9af15  update dummy args (youkaichao, Aug 12, 2024)
a560633  re-loadable (youkaichao, Aug 12, 2024)
eeff503  Merge branch 'main' into entrypoint_plugin (youkaichao, Aug 12, 2024)
e197d35  skip spell check for dummy model (youkaichao, Aug 12, 2024)
0074563  Merge branch 'entrypoint_plugin' of github.com:youkaichao/vllm into e… (youkaichao, Aug 12, 2024)
bd828b7  add vllm/plugins (youkaichao, Aug 13, 2024)
0b181e4  Merge branch 'main' into entrypoint_plugin (youkaichao, Aug 13, 2024)
6ba403c  Merge branch 'main' into entrypoint_plugin (youkaichao, Aug 13, 2024)
ae533fe  add files (youkaichao, Aug 13, 2024)
d69641a  lint (youkaichao, Aug 13, 2024)
aa57613  update (youkaichao, Aug 13, 2024)
5b5b06a  download script every time (youkaichao, Aug 13, 2024)
63193e8  remove model files (youkaichao, Aug 13, 2024)
7f64ac2  avoid too many downloads (youkaichao, Aug 13, 2024)
4457c32  Merge branch 'main' into entrypoint_plugin (youkaichao, Aug 13, 2024)
f9cd434  lint (youkaichao, Aug 13, 2024)
b03ab18  change to VLLM_PLUGINS (youkaichao, Aug 13, 2024)
d524219  revert changes (youkaichao, Aug 13, 2024)
2f07344  add tests (youkaichao, Aug 13, 2024)
f8797e5  add distributed tests (youkaichao, Aug 13, 2024)
b3099c2  add distributed tests (youkaichao, Aug 13, 2024)
3ad4c8e  add distributed tests (youkaichao, Aug 13, 2024)
18d2724  Merge branch 'main' into entrypoint_plugin (youkaichao, Aug 13, 2024)
27b77d8  add distributed tests (youkaichao, Aug 13, 2024)
e9471d2  Merge branch 'entrypoint_plugin' of github.com:youkaichao/vllm into e… (youkaichao, Aug 13, 2024)
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -77,11 +77,13 @@ steps:
- pytest -v -s core

- label: Entrypoints Test # 20min
working_dir: "/vllm-workspace/tests"
fast_check: true
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm
- pytest -v -s entrypoints/openai

@@ -154,6 +156,7 @@ steps:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models -m \"not vlm\"

- label: Vision Language Models Test # 42min
@@ -289,6 +292,7 @@ steps:
- pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

1 change: 1 addition & 0 deletions requirements-common.txt
@@ -23,4 +23,5 @@ pyzmq
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
compressed-tensors == 0.5.0
26 changes: 26 additions & 0 deletions tests/conftest.py
@@ -1,7 +1,9 @@
import contextlib
import gc
import json
import os
import sys
import tempfile
from collections import UserList
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
@@ -11,6 +13,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
@@ -757,3 +760,26 @@ def num_gpus_available():
in current process."""

return cuda_device_count_stateless()


temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
json_path = os.path.join(_dummy_path, "config.json")
if not os.path.exists(_dummy_path):
snapshot_download(repo_id="facebook/opt-125m",
local_dir=_dummy_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
])
assert os.path.exists(json_path)
with open(json_path, "r") as f:
config = json.load(f)
config["architectures"] = ["MyOPTForCausalLM"]
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_path
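
To make the fixture's effect concrete, here is a small sketch (a hypothetical check, not part of the diff) of what the rewritten snapshot contains after the fixture runs:

import json
import os
import tempfile

# The local facebook/opt-125m snapshot now names the out-of-tree
# architecture instead of the stock OPTForCausalLM.
config_path = os.path.join(tempfile.gettempdir(), "dummy_opt", "config.json")
with open(config_path) as f:
    assert json.load(f)["architectures"] == ["MyOPTForCausalLM"]
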
6 changes: 6 additions & 0 deletions tests/distributed/test_distributed_oot.py
@@ -0,0 +1,6 @@
from ..entrypoints.openai.test_oot_registration import (
run_and_test_dummy_opt_api_server)


def test_distributed_oot(dummy_opt_path: str):
run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)
106 changes: 27 additions & 79 deletions tests/entrypoints/openai/test_oot_registration.py
@@ -1,94 +1,42 @@
import sys
import time

import torch
from openai import OpenAI, OpenAIError

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port

from ...utils import VLLM_PATH, RemoteOpenAIServer

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits


def server_function(port: int):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)

sys.argv = ["placeholder.py"] + [
"--model",
"facebook/opt-125m",
def run_and_test_dummy_opt_api_server(model, tp=1):
# the model is registered through the plugin
server_args = [
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float32",
"--api-key",
"token-abc123",
"--port",
str(port),
"--chat-template",
str(chatml_jinja_path),
"--load-format",
"dummy",
"-tp",
f"{tp}",
]

import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')


def test_oot_registration_for_api_server():
port = get_open_port()
ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start()

try:
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
with RemoteOpenAIServer(model, server_args) as server:
client = server.get_client()
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
now = time.time()
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
msg = "Server did not start in time"
raise RuntimeError(msg) from e
else:
raise e
finally:
server.terminate()
generated_text = completion.choices[0].message.content
assert generated_text is not None
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""


generated_text = completion.choices[0].message.content
assert generated_text is not None
# make sure only the first token is generated
# TODO(youkaichao): Fix the test with plugin
rest = generated_text.replace("<s>", "") # noqa
# assert rest == ""
def test_oot_registration_for_api_server(dummy_opt_path: str):
run_and_test_dummy_opt_api_server(dummy_opt_path)
35 changes: 15 additions & 20 deletions tests/models/test_oot_registration.py
@@ -1,32 +1,27 @@
from typing import Optional
import os

import torch
import pytest

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm import LLM, SamplingParams

# NOTE: the order of the tests is important
# the first test does not load any plugins
# the second test loads the plugin
# they share the same process, so the plugin is loaded for the second test

class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits
def test_plugin(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = ""
with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy")
assert "are not supported for now" in str(excinfo.value)


def test_oot_registration():
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
def test_oot_registration(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="facebook/opt-125m")
llm = LLM(model=dummy_opt_path, load_format="dummy")
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)

9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_model/setup.py
@@ -0,0 +1,9 @@
from setuptools import setup

setup(name='vllm_add_dummy_model',
version='0.1',
packages=['vllm_add_dummy_model'],
entry_points={
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_model:register"]
})
26 changes: 26 additions & 0 deletions tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
@@ -0,0 +1,26 @@
from typing import Optional

import torch

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(
self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
if logits is not None:
logits.zero_()
logits[:, 0] += 1.0
return logits


def register():
# register our dummy model
if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
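
For orientation, a minimal usage sketch (the model path is assumed, standing in for the dummy_opt_path fixture): once the package above is pip-installed, vLLM discovers the register_dummy_model entry point and the out-of-tree architecture can be loaded by name.

from vllm import LLM, SamplingParams

# "/tmp/dummy_opt" is a placeholder for the fixture's snapshot, whose
# config.json lists "MyOPTForCausalLM" as the architecture.
llm = LLM(model="/tmp/dummy_opt", load_format="dummy")
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))
# The dummy model always predicts token id 0, so the output is degenerate on purpose.
print(outputs[0].outputs[0].text)
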
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -227,6 +227,9 @@ def __init__(
)
# TODO(woosuk): Print more configs in debug mode.

from vllm.plugins import load_general_plugins
load_general_plugins()

self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
10 changes: 9 additions & 1 deletion vllm/envs.py
@@ -1,6 +1,6 @@
import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
@@ -55,6 +55,7 @@
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_PLUGINS: Optional[List[str]] = None


def get_default_cache_root():
@@ -362,6 +363,13 @@ def get_default_config_root():
lambda:
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")),

# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded
"VLLM_PLUGINS":
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
"VLLM_PLUGINS"].split(","),
}

# end-env-vars-definition
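
As a quick illustration of the semantics above (a sketch mirroring the model tests, not part of the diff): leaving VLLM_PLUGINS unset loads every discovered plugin, an empty string loads none, and a comma-separated list loads only the named plugins.

import os

# Allow only the dummy-model plugin; "" would disable all plugins,
# and deleting the variable would allow every discovered plugin.
os.environ["VLLM_PLUGINS"] = "register_dummy_model"

from vllm.plugins import load_general_plugins
load_general_plugins()
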
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
@@ -166,7 +166,7 @@ def resolve_model_cls(

@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())

@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
31 changes: 31 additions & 0 deletions vllm/plugins/__init__.py
@@ -0,0 +1,31 @@
import logging

import vllm.envs as envs

logger = logging.getLogger(__name__)


def load_general_plugins():
"""WARNING: plugins can be loaded for multiple times in different
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
"""
import sys
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points

allowed_plugins = envs.VLLM_PLUGINS

discovered_plugins = entry_points(group='vllm.general_plugins')
for plugin in discovered_plugins:
logger.info("Found general plugin: %s", plugin.name)
if allowed_plugins is None or plugin.name in allowed_plugins:
try:
func = plugin.load()
func()
logger.info("Loaded general plugin: %s", plugin.name)
except Exception:
logger.exception("Failed to load general plugin: %s",
plugin.name)
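
For reference, a minimal sketch (assuming Python 3.10+) of checking that a plugin's entry point is visible to the loader above:

from importlib.metadata import entry_points

# Each entry point maps a plugin name to a "module:callable" target,
# e.g. register_dummy_model -> vllm_add_dummy_model:register.
for ep in entry_points(group="vllm.general_plugins"):
    print(ep.name, ep.value)
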
3 changes: 3 additions & 0 deletions vllm/worker/worker_base.py
@@ -411,6 +411,9 @@ def init_worker(self, *args, **kwargs):
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

from vllm.plugins import load_general_plugins
load_general_plugins()

if self.worker_class_fn:
worker_class = self.worker_class_fn()
else: