diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index bd24072f4c1a..67902b41b284 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -18,7 +18,7 @@
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
@@ -137,7 +137,7 @@ def determine_available_memory(self) -> int:
         kv_caches: dict[str, torch.Tensor] = {}
         kv_cache_spec = self.model_runner.get_kv_cache_spec()
         for layer_name, layer_spec in kv_cache_spec.items():
-            if isinstance(layer_spec, FullAttentionSpec):
+            if isinstance(layer_spec, AttentionSpec):
                 dtype = layer_spec.dtype
 
                 # Use an empty tensor instead of `None`` to force Dynamo to pass
@@ -147,7 +147,8 @@ def determine_available_memory(self) -> int:
                                             device=self.device)
                 kv_caches[layer_name] = tpu_kv_cache
             else:
-                raise NotImplementedError
+                raise NotImplementedError(
+                    f"Unsupported KV cache spec '{type(layer_spec)}'")
 
         runner_kv_caches: list[torch.Tensor] = []
         bind_kv_cache(