Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions vllm/v1/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState

# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 128


class PallasAttentionBackend(AttentionBackend):

Expand Down Expand Up @@ -115,13 +111,6 @@ def __init__(
tpu_version = torch_xla.tpu.version()
if tpu_version < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
if tpu_version == 4:
self.vmem_limit_bytes = 16 * 1024 * 1024
else:
self.vmem_limit_bytes = 64 * 1024 * 1024

def forward(
self,
Expand Down Expand Up @@ -165,9 +154,12 @@ def forward(
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
vmem_limit_bytes=self.vmem_limit_bytes,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=self.scale,
sliding_window=self.sliding_window,
Expand Down
8 changes: 2 additions & 6 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
PallasAttentionBackend,
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
Expand Down Expand Up @@ -155,11 +154,8 @@ def __init__(
dtype=torch.int64,
device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy()

padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros(
(self.max_num_tokens, padded_max_num_blocks_per_req),
(self.max_num_tokens, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu")

Expand Down