13 changes: 9 additions & 4 deletions tests/lora/test_add_lora.py
@@ -7,6 +7,7 @@
import pytest
from huggingface_hub import snapshot_download

import vllm.envs as env
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
@@ -144,10 +145,14 @@ async def test_add_lora():
await requests_processing_time(llm, dummy_run_requests)

# Run with warmup
for lr in warmup_run_requests:
await llm.add_lora(lr)
# Wait for the add_lora function to complete on the server side.
await asyncio.sleep(30)
add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
add_lora_results = await asyncio.gather(*add_lora_tasks)
if env.VLLM_USE_V1:
# Test that all add_lora calls are successful.
assert all(add_lora_results)
else:
# No way to check V0 engine results as the calls just return None.
pass
time_with_add_lora = await requests_processing_time(
llm, warmup_run_requests)

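Note: the warmup step above no longer waits on a fixed 30-second sleep. The add_lora calls are issued concurrently and, on the V1 engine, their boolean results are awaited and checked directly. A minimal sketch of that pattern, assuming only an object exposing an async add_lora(request) -> bool (the helper name below is illustrative, not part of this diff):

import asyncio

async def load_adapters(llm, lora_requests):
    # Issue all add_lora calls concurrently and wait for their results
    # instead of sleeping for a fixed interval.
    results = await asyncio.gather(*(llm.add_lora(r) for r in lora_requests))
    # On the V1 engine each call reports success as a bool.
    assert all(results)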
137 changes: 137 additions & 0 deletions tests/lora/test_lora_functions.py
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
"""
Script to test add_lora, remove_lora, pin_lora, list_loras functions.
"""

import os
from typing import List

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.lora.request import LoRARequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"
LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
LORA_RANK = 8


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def make_lora_request(lora_id: int):
return LoRARequest(lora_name=f"{lora_id}",
lora_int_id=lora_id,
lora_path=LORA_MODULE_PATH)


def test_lora_functions_sync():

max_loras = 4
# Create engine in eager-mode. Due to high max_loras, the CI can
# OOM during cuda-graph capture.
engine_args = EngineArgs(model=MODEL_PATH,
enable_lora=True,
max_loras=max_loras,
max_lora_rank=LORA_RANK,
max_model_len=128,
gpu_memory_utilization=0.8,
enforce_eager=True)

llm = LLM.get_engine_class().from_engine_args(engine_args)

def run_check(fn, args, expected: List):
fn(args)
assert set(llm.list_loras()) == set(expected)

run_check(llm.add_lora, make_lora_request(1), [1])
run_check(llm.add_lora, make_lora_request(2), [1, 2])

# Pin LoRA 1 and test that it is never removed on subsequent adds.
run_check(llm.pin_lora, 1, [1, 2])
run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])

# Remove LoRA 1 and continue adding.
run_check(llm.remove_lora, 1, [8, 9, 10])
run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])

# Remove all LoRAs
run_check(llm.remove_lora, 13, [12, 10, 11])
run_check(llm.remove_lora, 12, [10, 11])
run_check(llm.remove_lora, 11, [10])
run_check(llm.remove_lora, 10, [])


@pytest.mark.asyncio
async def test_lora_functions_async():

if os.getenv("VLLM_USE_V1") == "0":
pytest.skip(
reason=
"V0 AsyncLLMEngine does not expose remove/list/pin LoRA functions")

# The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
# environment variable. Reload vllm.engine.async_llm_engine, as
# vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
# env var.
import importlib

import vllm.engine.async_llm_engine
importlib.reload(vllm.engine.async_llm_engine)
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)

max_loras = 4
engine_args = AsyncEngineArgs(model=MODEL_PATH,
enable_lora=True,
max_loras=max_loras,
max_lora_rank=LORA_RANK,
max_model_len=128,
gpu_memory_utilization=0.8,
enforce_eager=True)

async def run_check(fn, args, expected: List):
await fn(args)
assert set(await llm.list_loras()) == set(expected)

async with build_async_engine_client_from_engine_args(engine_args) as llm:
await run_check(llm.add_lora, make_lora_request(1), [1])
await run_check(llm.add_lora, make_lora_request(2), [1, 2])

# Pin LoRA 1 and test that it is never removed on subsequent adds.
await run_check(llm.pin_lora, 1, [1, 2])
await run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
await run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
await run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
await run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
await run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
await run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
await run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
await run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])

# Remove LoRA 1 and continue adding.
await run_check(llm.remove_lora, 1, [8, 9, 10])
await run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
await run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
await run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])

# Remove all LoRAs
await run_check(llm.remove_lora, 13, [12, 10, 11])
await run_check(llm.remove_lora, 12, [10, 11])
await run_check(llm.remove_lora, 11, [10])
await run_check(llm.remove_lora, 10, [])
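Note: the expected sets asserted in both tests encode one policy: the adapter cache holds at most max_loras entries, the oldest unpinned adapter is evicted when a new one is added to a full cache, and pinned adapters are never evicted (though they can still be removed explicitly). A minimal model that reproduces the sequences above; it is an illustration of the behavior the assertions imply, not vLLM's implementation:

from typing import List, Set

class AdapterSlots:
    """Toy model: capacity-bound set that evicts the oldest unpinned id."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.ids: List[int] = []  # insertion order, oldest first
        self.pinned: Set[int] = set()

    def add(self, lora_id: int) -> None:
        if lora_id in self.ids:
            return
        if len(self.ids) == self.capacity:
            # Evict the oldest id that is not pinned.
            victim = next(i for i in self.ids if i not in self.pinned)
            self.ids.remove(victim)
        self.ids.append(lora_id)

    def pin(self, lora_id: int) -> None:
        self.pinned.add(lora_id)

    def remove(self, lora_id: int) -> None:
        self.ids.remove(lora_id)
        self.pinned.discard(lora_id)

slots = AdapterSlots(capacity=4)
slots.add(1)
slots.add(2)
slots.pin(1)
for i in range(3, 11):
    slots.add(i)
assert set(slots.ids) == {1, 8, 9, 10}  # matches the expectation after add_lora(10)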
18 changes: 15 additions & 3 deletions vllm/v1/engine/async_llm.py
@@ -2,7 +2,7 @@

import asyncio
import os
from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
from typing import AsyncGenerator, List, Mapping, Optional, Set, Type, Union

import numpy as np

@@ -367,9 +367,21 @@ async def sleep(self, level: int = 1) -> None:
async def wake_up(self) -> None:
await self.engine_core.wake_up_async()

async def add_lora(self, lora_request: LoRARequest) -> None:
async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
await self.engine_core.add_lora_async(lora_request)
return await self.engine_core.add_lora_async(lora_request)

async def remove_lora(self, lora_id: int) -> bool:
"""Remove an already loaded LoRA adapter."""
return await self.engine_core.remove_lora_async(lora_id)

async def list_loras(self) -> Set[int]:
"""List all registered adapters."""
return await self.engine_core.list_loras_async()

async def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return await self.engine_core.pin_lora_async(lora_id)

@property
def is_running(self) -> bool:
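Note: with these additions the V1 AsyncLLM exposes the full set of LoRA-management coroutines and propagates their results, which is what the new async test exercises via build_async_engine_client_from_engine_args. A minimal usage sketch, assuming an already-constructed engine object that provides these four coroutines (illustrative only):

async def manage_adapter(engine, lora_request):
    # Load the adapter and protect it from eviction.
    assert await engine.add_lora(lora_request)
    assert await engine.pin_lora(lora_request.lora_int_id)

    # list_loras returns the set of currently registered adapter ids.
    assert lora_request.lora_int_id in await engine.list_loras()

    # Explicit removal works even for pinned adapters.
    assert await engine.remove_lora(lora_request.lora_int_id)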
15 changes: 12 additions & 3 deletions vllm/v1/engine/core.py
@@ -7,7 +7,7 @@
from concurrent.futures import Future
from inspect import isclass, signature
from multiprocessing.connection import Connection
from typing import Any, List, Optional, Tuple, Type
from typing import Any, List, Optional, Set, Tuple, Type

import msgspec
import psutil
@@ -222,8 +222,17 @@ def wake_up(self):
def execute_dummy_batch(self):
self.model_executor.collective_rpc("execute_dummy_batch")

def add_lora(self, lora_request: LoRARequest) -> None:
self.model_executor.add_lora(lora_request)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
return self.model_executor.remove_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.model_executor.list_loras()

def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)


class EngineCoreProc(EngineCore):
63 changes: 54 additions & 9 deletions vllm/v1/engine/core_client.py
@@ -9,7 +9,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import Future
from threading import Thread
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Set, Type, Union

import zmq
import zmq.asyncio
@@ -96,7 +96,16 @@ async def execute_dummy_batch_async(self) -> None:
def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError

def add_lora(self, lora_request: LoRARequest) -> None:
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError

def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

def list_loras(self) -> Set[int]:
raise NotImplementedError

def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError

async def get_output_async(self) -> EngineCoreOutputs:
@@ -120,7 +129,16 @@ async def wake_up_async(self) -> None:
async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError

async def add_lora_async(self, lora_request: LoRARequest) -> None:
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError

async def remove_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError

async def list_loras_async(self) -> Set[int]:
raise NotImplementedError

async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError


@@ -165,8 +183,17 @@ def wake_up(self) -> None:
def execute_dummy_batch(self) -> None:
self.engine_core.execute_dummy_batch()

def add_lora(self, lora_request: LoRARequest) -> None:
self.engine_core.add_lora(lora_request)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.engine_core.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
return self.engine_core.remove_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.engine_core.list_loras()

def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id)


class MPClient(EngineCoreClient):
@@ -331,8 +358,17 @@ def profile(self, is_start: bool = True) -> None:
def reset_prefix_cache(self) -> None:
self._call_utility("reset_prefix_cache")

def add_lora(self, lora_request: LoRARequest) -> None:
self._call_utility("add_lora", lora_request)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self._call_utility("add_lora", lora_request)

def remove_lora(self, lora_id: int) -> bool:
return self._call_utility("remove_lora", lora_id)

def list_loras(self) -> Set[int]:
return self._call_utility("list_loras")

def pin_lora(self, lora_id: int) -> bool:
return self._call_utility("pin_lora", lora_id)

def sleep(self, level: int = 1) -> None:
self._call_utility("sleep", level)
@@ -429,5 +465,14 @@ async def wake_up_async(self) -> None:
async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch")

async def add_lora_async(self, lora_request: LoRARequest) -> None:
await self._call_utility_async("add_lora", lora_request)
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
return await self._call_utility_async("add_lora", lora_request)

async def remove_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("remove_lora", lora_id)

async def list_loras_async(self) -> Set[int]:
return await self._call_utility_async("list_loras")

async def pin_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("pin_lora", lora_id)
18 changes: 17 additions & 1 deletion vllm/v1/engine/llm_engine.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Mapping, Optional, Type, Union
from typing import Dict, List, Mapping, Optional, Set, Type, Union

from typing_extensions import TypeVar

@@ -217,3 +217,19 @@ def get_tokenizer_group(
f"found type: {type(tokenizer_group)}")

return tokenizer_group

def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
return self.engine_core.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
"""Remove an already loaded LoRA adapter."""
return self.engine_core.remove_lora(lora_id)

def list_loras(self) -> Set[int]:
"""List all registered adapters."""
return self.engine_core.list_loras()

def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id)
11 changes: 10 additions & 1 deletion vllm/v1/worker/gpu_worker.py
@@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Set

import torch
import torch.distributed
@@ -240,6 +240,15 @@ def execute_dummy_batch(self) -> None:
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()

def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)

def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return