diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 278986329802..af729ee9910f 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -11,10 +11,6 @@
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
-# These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 32
-NUM_KV_PAGES_PER_BLOCK = 128
-
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -115,13 +111,6 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
-        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
-        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
-        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
-        if tpu_version == 4:
-            self.vmem_limit_bytes = 16 * 1024 * 1024
-        else:
-            self.vmem_limit_bytes = 64 * 1024 * 1024
 
     def forward(
         self,
@@ -165,9 +154,12 @@ def forward(
             attn_metadata.block_tables,
             attn_metadata.query_start_loc,
             attn_metadata.num_seqs,
-            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
-            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            vmem_limit_bytes=self.vmem_limit_bytes,
+            # By default, the system utilizes optimized block size and
+            # vmem_limit_bytes parameters from the kernel repository. However,
+            # these can be manually adjusted for debugging if necessary.
+            num_kv_pages_per_block=None,
+            num_queries_per_block=None,
+            vmem_limit_bytes=None,
             use_kernel=True,
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index b1d5c0f33854..0668e7168b5f 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -24,8 +24,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -155,11 +154,8 @@ def __init__(
                                              dtype=torch.int64,
                                              device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
-
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")