From 1449a8c41f6f267e5fbaaefb255464954635f707 Mon Sep 17 00:00:00 2001
From: David Xia
Date: Wed, 30 Apr 2025 09:24:26 -0400
Subject: [PATCH] fix: `vllm serve` on Apple silicon

Commands like `vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0` currently
fail on Apple silicon with Triton import errors like these:

```
$ vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0
INFO 04-30 09:33:49 [importing.py:17] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 04-30 09:33:49 [importing.py:28] Triton is not installed. Using dummy decorators. Install it via `pip install triton` to enable kernelcompilation.
INFO 04-30 09:33:49 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 04-30 09:33:50 [__init__.py:239] Automatically detected platform cpu.
Traceback (most recent call last):
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/bin/vllm", line 5, in <module>
    from vllm.entrypoints.cli.main import main
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/entrypoints/cli/main.py", line 7, in <module>
    import vllm.entrypoints.cli.benchmark.main
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/entrypoints/cli/benchmark/main.py", line 6, in <module>
    import vllm.entrypoints.cli.benchmark.throughput
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/entrypoints/cli/benchmark/throughput.py", line 4, in <module>
    from vllm.benchmarks.throughput import add_cli_args, main
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/benchmarks/throughput.py", line 18, in <module>
    from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset,
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/benchmarks/datasets.py", line 34, in <module>
    from vllm.lora.utils import get_adapter_absolute_path
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/lora/utils.py", line 15, in <module>
    from vllm.lora.fully_sharded_layers import (
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/lora/fully_sharded_layers.py", line 14, in <module>
    from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/lora/layers.py", line 29, in <module>
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/model_executor/layers/logits_processor.py", line 13, in <module>
    from vllm.model_executor.layers.vocab_parallel_embedding import (
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/model_executor/layers/vocab_parallel_embedding.py", line 139, in <module>
    @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2543, in fn
    return compile(
           ^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2572, in compile
    return torch._dynamo.optimize(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 944, in optimize
    return _optimize(rebuild_ctx, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 998, in _optimize
    backend = get_compiler_fn(backend)
              ^^^^^^^^^^^^^^^^^^^^^^^^
"/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 878, in get_compiler_fn from .repro.after_dynamo import wrap_backend_debug File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 35, in from torch._dynamo.debug_utils import ( File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/debug_utils.py", line 44, in from torch._dynamo.testing import rand_strided File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/testing.py", line 33, in from torch._dynamo.backends.debugging import aot_eager File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py", line 35, in from functorch.compile import min_cut_rematerialization_partition File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/functorch/compile/__init__.py", line 2, in from torch._functorch.aot_autograd import ( File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 26, in from torch._inductor.output_code import OutputCode File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 52, in from .runtime.autotune_cache import AutotuneCacheBundler File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py", line 23, in from .triton_compat import Config File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py", line 16, in from triton import Config ImportError: cannot import name 'Config' from 'triton' (unknown location) ``` We cannot install `triton` on Apple silicon because there are no [available distributions][1]. This change adds more placeholders for triton modules and classes that are imported when calling `vllm serve`. [1]: https://pypi.org/project/triton/#files Signed-off-by: David Xia --- vllm/triton_utils/importing.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index fa29efbf6b2d..1cbb1f86d3ff 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -2,6 +2,7 @@ import sys import types +from abc import ABC from importlib.util import find_spec from vllm.logger import init_logger @@ -25,6 +26,8 @@ def __init__(self): self.autotune = self._dummy_decorator("autotune") self.heuristics = self._dummy_decorator("heuristics") self.language = TritonLanguagePlaceholder() + self.Config = self._dummy_decorator("Config") + self.__version__ = "" logger.warning_once( "Triton is not installed. Using dummy decorators. 
" "Install it via `pip install triton` to enable kernel" @@ -43,11 +46,36 @@ class TritonLanguagePlaceholder(types.ModuleType): def __init__(self): super().__init__("triton.language") - self.constexpr = None + self.constexpr = lambda x: x self.dtype = None + self.extra = None + self.math = None + self.tensor = None + + class TritonCompilerPlaceholder(types.ModuleType): + + def __init__(self): + super().__init__("triton.compiler") + self.CompiledKernel = ABC + + class TritonRuntimeAutotunerPlaceholder(types.ModuleType): + + def __init__(self): + super().__init__("triton.runtime.autotuner") + self.OutOfResources = ABC + + class TritonRuntimeJitPlaceholder(types.ModuleType): + + def __init__(self): + super().__init__("triton.runtime.jit") + self.KernelInterface = ABC sys.modules['triton'] = TritonPlaceholder() sys.modules['triton.language'] = TritonLanguagePlaceholder() + sys.modules['triton.compiler'] = TritonCompilerPlaceholder() + sys.modules[ + 'triton.runtime.autotuner'] = TritonRuntimeAutotunerPlaceholder() + sys.modules['triton.runtime.jit'] = TritonRuntimeJitPlaceholder() if 'triton' in sys.modules: logger.info("Triton module has been replaced with a placeholder.")