20 changes: 15 additions & 5 deletions tests/entrypoints/llm/test_accuracy.py
@@ -21,10 +21,13 @@
EXPECTED_VALUE = 0.58


-def run_test():
+def run_test(more_args=None):
    """Run the end to end accuracy test."""

-    model_args = f"pretrained={MODEL_NAME},max_model_len=2048"
+    model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
Review comment (Collaborator):
Just curious, why do we need to increase the model len?

Reply from @alexm-redhat (Collaborator, Author), Feb 14, 2025:
There is a compilation-related issue with 2048 that I was not able to fully resolve. I will dig into this more in the next PR.

+    if more_args is not None:
+        model_args = "{},{}".format(model_args, more_args)

    results = lm_eval.simple_evaluate(
        model="vllm",
@@ -39,14 +42,21 @@ def run_test():
    ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


-@pytest.mark.skipif(not current_platform.is_cuda(),
-                    reason="V1 is currently only supported on CUDA.")
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 is currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Run with the V1 Engine."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
-        run_test()
+
+        more_args = None
+        if current_platform.is_tpu():
+            # Limit compilation time for TPU V1
+            more_args = "max_num_seqs=64"
+
+        run_test(more_args)


def test_lm_eval_accuracy_v0_engine(monkeypatch):
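Taken together, the change above threads an optional `more_args` string into the comma-separated `model_args` that lm-eval hands to vLLM, and on TPU it caps `max_num_seqs` to keep compilation time bounded. A minimal sketch of that flow; the model name and task below are placeholders, not necessarily the constants defined at the top of this test file:

# Sketch only: mirrors the pattern introduced above; MODEL_NAME and TASK
# are illustrative placeholders, not the test's actual constants.
import lm_eval

MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"   # placeholder model
TASK = "gsm8k"                            # placeholder task


def run_test(more_args=None):
    model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
    if more_args is not None:
        # Extra engine args arrive as a comma-separated suffix,
        # e.g. "max_num_seqs=64" on TPU to limit compilation time.
        model_args = f"{model_args},{more_args}"
    return lm_eval.simple_evaluate(model="vllm",
                                   model_args=model_args,
                                   tasks=[TASK],
                                   batch_size="auto")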
15 changes: 11 additions & 4 deletions tests/entrypoints/openai/correctness/test_lmeval.py
@@ -21,7 +21,7 @@
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
-DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
+DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [
    [],  # Default
    ["--enable-chunked-prefill"],  # Chunked
@@ -67,14 +67,21 @@ def run_test(more_args):
    ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


-@pytest.mark.skipif(not current_platform.is_cuda(),
-                    reason="V1 currently only supported on CUDA")
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Run with the V1 Engine."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
-        run_test([])
+        more_args = []
+
+        # Limit compilation time for V1
+        if current_platform.is_tpu():
+            more_args = ["--max-num-seqs", "64"]
+
+        run_test(more_args)


@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
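The server-based correctness test expresses the same limit as CLI flags appended to `DEFAULT_ARGS` rather than as a `model_args` suffix. A rough sketch of how the argument lists combine; the server-launch step is left as a comment because the fixture itself is outside this diff:

# Illustrative composition of server arguments for the V1 TPU case.
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]


def build_server_args(on_tpu: bool) -> list[str]:
    more_args: list[str] = []
    if on_tpu:
        # Cap concurrent sequences to keep TPU compilation time bounded.
        more_args = ["--max-num-seqs", "64"]
    return DEFAULT_ARGS + more_args

# The harness would then launch something equivalent to:
#   vllm serve <MODEL_NAME> --max-model-len 4096 --disable-log-requests \
#       --max-num-seqs 64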
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -37,6 +37,7 @@ class _Backend(enum.Enum):
    TRITON_MLA = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
+    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()
54 changes: 38 additions & 16 deletions vllm/platforms/tpu.py
@@ -4,6 +4,7 @@

import torch

+import vllm.envs as envs
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend
@@ -33,22 +34,28 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
            logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
-        raise NotImplementedError
+        return "tpu"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1

    @classmethod
    def inference_mode(cls):
@@ -63,22 +70,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
            cache_config.block_size = 16

        compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION

+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

-        assert vllm_config.speculative_config is None, \
-            "TPU does not support speculative decoding"

        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
            "Chunked prefill is not yet supported for TPU backend")
        assert not vllm_config.speculative_config, (
            "Speculative decoding is not yet supported for TPU backend")
        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
@@ -88,8 +91,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
            else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
+
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False
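For orientation, a hedged sketch of the post-conditions the V1 TPU branch of `check_and_update_config` is expected to establish; the import paths and the `TpuPlatform` class name are assumed from context rather than shown in this diff:

# Sketch, not a test shipped with this PR: spells out what the V1 TPU path
# above should leave behind when worker_cls started out as "auto".
import os

os.environ["VLLM_USE_V1"] = "1"   # must be set before vllm.envs is consulted

from vllm.config import CompilationLevel, VllmConfig  # assumed import paths
from vllm.platforms.tpu import TpuPlatform             # assumed class name


def check_v1_tpu_defaults(vllm_config: VllmConfig) -> None:
    TpuPlatform.check_and_update_config(vllm_config)
    # Compilation is forced to DYNAMO_ONCE with the openxla backend.
    assert vllm_config.compilation_config.level == CompilationLevel.DYNAMO_ONCE
    assert vllm_config.compilation_config.backend == "openxla"
    # Prefix caching is disabled on TPU V1 for now.
    assert not vllm_config.cache_config.enable_prefix_caching
    # The V1 TPU worker is selected.
    assert vllm_config.parallel_config.worker_cls == (
        "vllm.v1.worker.tpu_worker.TPUWorker")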