diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 9b9d46450d2a..ff2f0387bf5d 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -77,11 +77,13 @@ steps:
   - pytest -v -s core

 - label: Entrypoints Test # 20min
+  working_dir: "/vllm-workspace/tests"
   fast_check: true
   mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s entrypoints/llm
   - pytest -v -s entrypoints/openai

@@ -154,6 +156,7 @@ steps:
   - vllm/
   - tests/models
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s models -m \"not vlm\"

 - label: Vision Language Models Test # 42min
@@ -289,6 +292,7 @@
   - pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
+  - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
diff --git a/requirements-common.txt b/requirements-common.txt
index 2f006c887dab..67de33c57873 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -23,4 +23,5 @@ pyzmq
 librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
+importlib_metadata
 compressed-tensors == 0.5.0
diff --git a/tests/conftest.py b/tests/conftest.py
index ba764223a29e..6e033e76964b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,9 @@
 import contextlib
 import gc
+import json
 import os
 import sys
+import tempfile
 from collections import UserList
 from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
@@ -11,6 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from huggingface_hub import snapshot_download
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                           AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
@@ -757,3 +760,26 @@ def num_gpus_available():
     in current process."""

     return cuda_device_count_stateless()
+
+
+temp_dir = tempfile.gettempdir()
+_dummy_path = os.path.join(temp_dir, "dummy_opt")
+
+
+@pytest.fixture
+def dummy_opt_path():
+    json_path = os.path.join(_dummy_path, "config.json")
+    if not os.path.exists(_dummy_path):
+        snapshot_download(repo_id="facebook/opt-125m",
+                          local_dir=_dummy_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+    assert os.path.exists(json_path)
+    with open(json_path, "r") as f:
+        config = json.load(f)
+    config["architectures"] = ["MyOPTForCausalLM"]
+    with open(json_path, "w") as f:
+        json.dump(config, f)
+    return _dummy_path
diff --git a/tests/distributed/test_distributed_oot.py b/tests/distributed/test_distributed_oot.py
new file mode 100644
index 000000000000..62e77a2f7759
--- /dev/null
+++ b/tests/distributed/test_distributed_oot.py
@@ -0,0 +1,6 @@
+from ..entrypoints.openai.test_oot_registration import (
+    run_and_test_dummy_opt_api_server)
+
+
+def test_distributed_oot(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)
diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py
index de72fb79b7d4..b25cb1d0e722 100644
--- a/tests/entrypoints/openai/test_oot_registration.py
+++ b/tests/entrypoints/openai/test_oot_registration.py
@@ -1,94 +1,42 @@
-import sys
-import time
-
-import torch
-from openai import OpenAI, OpenAIError
-
-from vllm import ModelRegistry
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.utils import get_open_port
-
 from ...utils import VLLM_PATH, RemoteOpenAIServer

 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()


-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
-
-
-def server_function(port: int):
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-
-    sys.argv = ["placeholder.py"] + [
-        "--model",
-        "facebook/opt-125m",
+def run_and_test_dummy_opt_api_server(model, tp=1):
+    # the model is registered through the plugin
+    server_args = [
         "--gpu-memory-utilization",
         "0.10",
         "--dtype",
         "float32",
-        "--api-key",
-        "token-abc123",
-        "--port",
-        str(port),
         "--chat-template",
         str(chatml_jinja_path),
+        "--load-format",
+        "dummy",
+        "-tp",
+        f"{tp}",
     ]
-
-    import runpy
-    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
-
-
-def test_oot_registration_for_api_server():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context()
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-
-    try:
-        client = OpenAI(
-            base_url=f"http://localhost:{port}/v1",
-            api_key="token-abc123",
+    with RemoteOpenAIServer(model, server_args) as server:
+        client = server.get_client()
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "Hello!"
+            }],
+            temperature=0,
         )
-        now = time.time()
-        while True:
-            try:
-                completion = client.chat.completions.create(
-                    model="facebook/opt-125m",
-                    messages=[{
-                        "role": "system",
-                        "content": "You are a helpful assistant."
-                    }, {
-                        "role": "user",
-                        "content": "Hello!"
-                    }],
-                    temperature=0,
-                )
-                break
-            except OpenAIError as e:
-                if "Connection error" in str(e):
-                    time.sleep(3)
-                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
-                        msg = "Server did not start in time"
-                        raise RuntimeError(msg) from e
-                else:
-                    raise e
-    finally:
-        server.terminate()
+        generated_text = completion.choices[0].message.content
+        assert generated_text is not None
+        # make sure only the first token is generated
+        rest = generated_text.replace("<s>", "")
+        assert rest == ""
+

-    generated_text = completion.choices[0].message.content
-    assert generated_text is not None
-    # make sure only the first token is generated
-    # TODO(youkaichao): Fix the test with plugin
-    rest = generated_text.replace("<s>", "")  # noqa
-    # assert rest == ""
+def test_oot_registration_for_api_server(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path)
diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py
index 4918593ff0f9..3eae23efb285 100644
--- a/tests/models/test_oot_registration.py
+++ b/tests/models/test_oot_registration.py
@@ -1,32 +1,27 @@
-from typing import Optional
+import os

-import torch
+import pytest

-from vllm import LLM, ModelRegistry, SamplingParams
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm import LLM, SamplingParams

+# NOTE: the order of the tests is important
+# the first test does not load any plugins
+# the second test loads the plugin
+# they share the same process, so the plugin is loaded for the second test

-class MyOPTForCausalLM(OPTForCausalLM):

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
+def test_plugin(dummy_opt_path):
+    os.environ["VLLM_PLUGINS"] = ""
+    with pytest.raises(Exception) as excinfo:
+        LLM(model=dummy_opt_path, load_format="dummy")
+    assert "are not supported for now" in str(excinfo.value)


-def test_oot_registration():
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
+def test_oot_registration(dummy_opt_path):
+    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = ["Hello, my name is", "The text does not matter"]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="facebook/opt-125m")
+    llm = LLM(model=dummy_opt_path, load_format="dummy")
     first_token = llm.get_tokenizer().decode(0)

     outputs = llm.generate(prompts, sampling_params)
diff --git a/tests/plugins/vllm_add_dummy_model/setup.py b/tests/plugins/vllm_add_dummy_model/setup.py
new file mode 100644
index 000000000000..9b535127f1df
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_model/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(name='vllm_add_dummy_model',
+      version='0.1',
+      packages=['vllm_add_dummy_model'],
+      entry_points={
+          'vllm.general_plugins':
+          ["register_dummy_model = vllm_add_dummy_model:register"]
+      })
diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
new file mode 100644
index 000000000000..dcc0305e657a
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
@@ -0,0 +1,26 @@
+from typing import Optional
+
+import torch
+
+from vllm import ModelRegistry
+from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class MyOPTForCausalLM(OPTForCausalLM):
+
+    def compute_logits(
+            self, hidden_states: torch.Tensor,
+            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
+        # this dummy model always predicts the first token
+        logits = super().compute_logits(hidden_states, sampling_metadata)
+        if logits is not None:
+            logits.zero_()
+            logits[:, 0] += 1.0
+        return logits
+
+
+def register():
+    # register our dummy model
+    if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1191d0c66044..a25d60bc0aa3 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -227,6 +227,9 @@ def __init__(
         )
         # TODO(woosuk): Print more configs in debug mode.

+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
diff --git a/vllm/envs.py b/vllm/envs.py
index ca8ec96d07aa..22b2aa37a925 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 if TYPE_CHECKING:
     VLLM_HOST_IP: str = ""
@@ -55,6 +55,7 @@
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
+    VLLM_PLUGINS: Optional[List[str]] = None


 def get_default_cache_root():
@@ -362,6 +363,13 @@ def get_default_config_root():
     lambda:
     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
      ("1", "true")),
+
+    # a list of plugin names to load, separated by commas.
+    # if this is not set, it means all plugins will be loaded
+    # if this is set to an empty string, no plugins will be loaded
+    "VLLM_PLUGINS":
+    lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
+        "VLLM_PLUGINS"].split(","),
 }

 # end-env-vars-definition
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 0f91b92665c2..46aa62e24e8a 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -166,7 +166,7 @@ def resolve_model_cls(

     @staticmethod
     def get_supported_archs() -> List[str]:
-        return list(_MODELS.keys())
+        return list(_MODELS.keys()) + list(_OOT_MODELS.keys())

     @staticmethod
     def register_model(model_arch: str, model_cls: Type[nn.Module]):
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
new file mode 100644
index 000000000000..765f74fe7356
--- /dev/null
+++ b/vllm/plugins/__init__.py
@@ -0,0 +1,31 @@
+import logging
+
+import vllm.envs as envs
+
+logger = logging.getLogger(__name__)
+
+
+def load_general_plugins():
+    """WARNING: plugins can be loaded for multiple times in different
+    processes. They should be designed in a way that they can be loaded
+    multiple times without causing issues.
+    """
+    import sys
+    if sys.version_info < (3, 10):
+        from importlib_metadata import entry_points
+    else:
+        from importlib.metadata import entry_points
+
+    allowed_plugins = envs.VLLM_PLUGINS
+
+    discovered_plugins = entry_points(group='vllm.general_plugins')
+    for plugin in discovered_plugins:
+        logger.info("Found general plugin: %s", plugin.name)
+        if allowed_plugins is None or plugin.name in allowed_plugins:
+            try:
+                func = plugin.load()
+                func()
+                logger.info("Loaded general plugin: %s", plugin.name)
+            except Exception:
+                logger.exception("Failed to load general plugin: %s",
+                                 plugin.name)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 85ab0d348e03..905052d1a951 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -411,6 +411,9 @@ def init_worker(self, *args, **kwargs):
         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'

+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
         if self.worker_class_fn:
             worker_class = self.worker_class_fn()
         else:
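
Usage sketch (illustrative note, not part of the diff): once a plugin package like tests/plugins/vllm_add_dummy_model above is installed with pip install -e, vLLM discovers its vllm.general_plugins entry point when the engine starts, and VLLM_PLUGINS narrows which entry points are actually run. The local model path below is hypothetical; it stands for any checkpoint whose config.json lists MyOPTForCausalLM under "architectures", as the dummy_opt_path fixture prepares.

import os

from vllm import LLM, SamplingParams

# Load only the entry point named "register_dummy_model"; leaving VLLM_PLUGINS
# unset would load every discovered vllm.general_plugins entry point instead.
os.environ["VLLM_PLUGINS"] = "register_dummy_model"

# "/tmp/dummy_opt" is a hypothetical local checkpoint whose config.json lists
# "MyOPTForCausalLM" in "architectures"; load_format="dummy" skips loading real
# weights, mirroring the tests in this diff.
llm = LLM(model="/tmp/dummy_opt", load_format="dummy")
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)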