4 changes: 3 additions & 1 deletion vllm/attention/layer.py
@@ -186,8 +186,10 @@ def forward(
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
+            output_dtype = (query.dtype if fp8_out_scale is None else
+                            current_platform.fp8_dtype())
             output = torch.empty(output_shape,
-                                 dtype=query.dtype,
+                                 dtype=output_dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
             # We skip reshaping query, key and value tensors for the MLA
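Note on the change above: the attention output buffer is now allocated directly in the platform fp8 dtype whenever an output scale is supplied, so the kernels can store quantized values in place. A minimal sketch of the dtype selection, assuming PyTorch >= 2.1 and using torch.float8_e4m3fnuz as a stand-in for current_platform.fp8_dtype() on ROCm:

import torch

def select_output_dtype(query: torch.Tensor, fp8_out_scale):
    # Keep the query dtype unless an fp8 output scale is provided; in that
    # case the output buffer is allocated in fp8 and holds quantized values.
    return query.dtype if fp8_out_scale is None else torch.float8_e4m3fnuz

query = torch.randn(4, 32, 128, dtype=torch.bfloat16)
fp8_out_scale = torch.tensor([0.5])
output = torch.empty(query.shape,
                     dtype=select_output_dtype(query, fp8_out_scale),
                     device=query.device)
assert output.dtype == torch.float8_e4m3fnuz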
83 changes: 45 additions & 38 deletions vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -23,41 +23,43 @@ def cdiv_fn(x, y):
 
 @triton.jit
 def kernel_paged_attention_2d(
-        output_ptr,  # [num_tokens, num_query_heads, head_size]
-        query_ptr,  # [num_tokens, num_query_heads, head_size]
-        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
-        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
-        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-        seq_lens_ptr,  # [num_seqs]
-        alibi_slopes_ptr,  # [num_query_heads]
-        scale,  # float32
-        k_scale,  # float32
-        v_scale,  # float32
-        num_query_heads: tl.constexpr,  # int
-        num_queries_per_kv: tl.constexpr,  # int
-        num_queries_per_kv_padded: tl.constexpr,  # int
-        block_table_stride: tl.int64,  # int
-        query_stride_0: tl.int64,  # int
-        query_stride_1: tl.int64,  # int, should be equal to head_size
-        output_stride_0: tl.int64,  # int
-        output_stride_1: tl.int64,  # int, should be equal to head_size
-        BLOCK_SIZE: tl.constexpr,  # int
-        HEAD_SIZE: tl.constexpr,  # int
-        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-        USE_ALIBI_SLOPES: tl.constexpr,  # bool
-        SLIDING_WINDOW: tl.constexpr,  # int
-        x: tl.constexpr,  # int
-        stride_k_cache_0: tl.int64,  # int
-        stride_k_cache_1: tl.int64,  # int
-        stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.int64,  # int
-        stride_k_cache_4: tl.int64,  # int
-        stride_v_cache_0: tl.int64,  # int
-        stride_v_cache_1: tl.int64,  # int
-        stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.int64,  # int
-        filter_by_query_len: tl.constexpr,  # bool
-        query_start_len_ptr,  # [num_seqs+1]
+        output_ptr,  # [num_tokens, num_query_heads, head_size]
+        query_ptr,  # [num_tokens, num_query_heads, head_size]
+        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
+        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+        seq_lens_ptr,  # [num_seqs]
+        alibi_slopes_ptr,  # [num_query_heads]
+        scale,  # float32
+        k_scale,  # float32
+        v_scale,  # float32
+        out_scale,
+        num_query_heads: tl.constexpr,  # int
+        num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
+        block_table_stride: tl.int64,  # int
+        query_stride_0: tl.int64,  # int
+        query_stride_1: tl.int64,  # int, should be equal to head_size
+        output_stride_0: tl.int64,  # int
+        output_stride_1: tl.int64,  # int, should be equal to head_size
+        BLOCK_SIZE: tl.constexpr,  # int
+        HEAD_SIZE: tl.constexpr,  # int
+        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+        USE_ALIBI_SLOPES: tl.constexpr,  # bool
+        SLIDING_WINDOW: tl.constexpr,  # int
+        x: tl.constexpr,  # int
+        stride_k_cache_0: tl.int64,  # int
+        stride_k_cache_1: tl.int64,  # int
+        stride_k_cache_2: tl.int64,  # int
+        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_4: tl.int64,  # int
+        stride_v_cache_0: tl.int64,  # int
+        stride_v_cache_1: tl.int64,  # int
+        stride_v_cache_2: tl.int64,  # int
+        stride_v_cache_3: tl.int64,  # int
+        filter_by_query_len: tl.constexpr,  # bool
+        query_start_len_ptr,  # [num_seqs+1]
+        USE_FP8: tl.constexpr,
 ):
     seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -192,6 +194,8 @@ def kernel_paged_attention_2d(
 
     # epilogue
     acc = acc / L[:, None]
+    if USE_FP8:
+        acc = acc / tl.load(out_scale)
 
     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)
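The new USE_FP8 flag is a tl.constexpr, so the division by the output scale is only compiled in when a scale is actually supplied; with USE_FP8=False the branch (and its tl.load) is removed at compile time, which is why passing out_scale=None is safe. A toy Triton kernel sketching the same pattern (hypothetical names, needs a CUDA/ROCm device):

import torch
import triton
import triton.language as tl

@triton.jit
def scaled_copy_kernel(x_ptr, out_ptr, out_scale_ptr, n,
                       USE_FP8: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    acc = tl.load(x_ptr + offs, mask=mask)
    if USE_FP8:
        # Same epilogue as above: divide the accumulator by the scalar output
        # scale; the store casts to whatever dtype out_ptr was allocated with.
        acc = acc / tl.load(out_scale_ptr)
    tl.store(out_ptr + offs, acc, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
scale = torch.tensor([0.125], device="cuda")
grid = (triton.cdiv(x.numel(), 256), )
scaled_copy_kernel[grid](x, out, scale, x.numel(), USE_FP8=True, BLOCK=256)
assert torch.allclose(out, x / 0.125)
# Disabled path: the guarded load is never compiled, so None is accepted.
scaled_copy_kernel[grid](x, out, None, x.numel(), USE_FP8=False, BLOCK=256)
assert torch.equal(out, x)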
Expand Down Expand Up @@ -222,8 +226,8 @@ def chunked_prefill_paged_decode(
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
fp8_out_scale=None,
):

if sm_scale is None:
sm_scale = 1.0 / (query.shape[1]**0.5)

@@ -252,6 +256,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            fp8_out_scale=fp8_out_scale,
         )
 
     block_size = value_cache.shape[3]
@@ -293,7 +298,7 @@ def chunked_prefill_paged_decode(
         tmp_output = torch.empty(
             size=(total_num_seq, num_query_heads, max_num_partitions,
                   head_size),
-            dtype=output.dtype,
+            dtype=query.dtype,
             device=output.device,
         )
         exp_sums = torch.empty(
Expand Down Expand Up @@ -322,7 +327,7 @@ def chunked_prefill_paged_decode(
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
fp8_out_scale=None,
fp8_out_scale=fp8_out_scale,
)
else:
kernel_paged_attention_2d[(
@@ -339,6 +344,7 @@ def chunked_prefill_paged_decode(
             scale=sm_scale,
             k_scale=k_scale,
             v_scale=v_scale,
+            out_scale=fp8_out_scale,
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
             num_queries_per_kv_padded=num_queries_per_kv_padded,
@@ -364,4 +370,5 @@ def chunked_prefill_paged_decode(
             stride_v_cache_3=value_cache.stride(3),
             filter_by_query_len=True,
             query_start_len_ptr=query_start_loc,
+            USE_FP8=fp8_out_scale is not None,
         )
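For context, the kernels divide by out_scale just before the store because this is standard per-tensor fp8 quantization: the stored value is acc / scale, and a downstream consumer multiplies by the same scale to dequantize. A plain-PyTorch sketch of the numerics (PyTorch >= 2.1; float8_e4m3fn is used here for portability, while the ROCm path targeted by this change uses the fnuz variant):

import torch

acc = torch.randn(16, 128)              # stand-in for the fp32 accumulator
fp8 = torch.float8_e4m3fn               # ROCm would use torch.float8_e4m3fnuz
out_scale = acc.abs().max() / torch.finfo(fp8).max

quantized = (acc / out_scale).to(fp8)             # what the kernel stores
dequantized = quantized.to(torch.float32) * out_scale
print((acc - dequantized).abs().max())            # small fp8 rounding error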
9 changes: 8 additions & 1 deletion vllm/attention/ops/prefix_prefill.py
@@ -29,6 +29,7 @@ def _fwd_kernel(
     sm_scale,
     k_scale,
     v_scale,
+    out_scale,
     B_Start_Loc,
     B_Seqlen,
     block_size,
@@ -65,6 +66,7 @@ def _fwd_kernel(
     BLOCK_N: tl.constexpr,
     SLIDING_WINDOW: tl.constexpr,
     SKIP_DECODE: tl.constexpr,
+    USE_FP8: tl.constexpr,
 ):
 
     cur_batch = tl.program_id(0)
@@ -263,6 +265,8 @@ def _fwd_kernel(
         (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
         cur_head * stride_oh + offs_d[None, :] * stride_od)
     out_ptrs = Out + off_o
+    if USE_FP8:
+        acc = acc / tl.load(out_scale)
     tl.store(out_ptrs,
              acc,
              mask=dim_mask[None, :] &
@@ -732,7 +736,8 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None,
                           sm_scale=None,
-                          skip_decode=False):
+                          skip_decode=False,
+                          fp8_out_scale=None):
 
     q_dtype_is_f32 = q.dtype is torch.float32
     # need to reduce num. blocks when using fp32
@@ -852,6 +857,7 @@ def context_attention_fwd(q,
         sm_scale,
         k_scale,
         v_scale,
+        fp8_out_scale,
         b_start_loc,
         b_seq_len,
         v_cache.shape[3],
@@ -890,6 +896,7 @@ def context_attention_fwd(q,
         BLOCK_N=BLOCK,
         SLIDING_WINDOW=sliding_window,
         SKIP_DECODE=skip_decode,
+        USE_FP8=fp8_out_scale is not None,
         num_warps=NUM_WARPS,
         num_stages=1,
     )
3 changes: 1 addition & 2 deletions vllm/model_executor/models/llama.py
@@ -205,8 +205,7 @@ def __init__(self,
         use_fp8 = isinstance(
             quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
                                          and quant_config.is_fp8_w8a8())
-        self.attn_fp8_out = (not envs.VLLM_USE_V1
-                             and envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT
+        self.attn_fp8_out = (envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT
                              and current_platform.is_fp8_fnuz() and use_fp8)
 
         self.attn = Attention(
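The llama.py change only drops the `not envs.VLLM_USE_V1` term, so the fp8-output path is no longer restricted to the V0 engine; the remaining gates (ROCm env flag, fnuz fp8 platform, fp8-quantized model) are unchanged. A simplified sketch of the gating, with stand-ins for the real envs / current_platform / quant-config helpers:

import os

# Stand-ins for envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT,
# current_platform.is_fp8_fnuz(), and the Fp8Config / Quark fp8-w8a8 check.
rocm_fp8_out_flag = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT",
                              "0") in ("1", "true", "True")
platform_is_fp8_fnuz = True
model_uses_fp8_quant = True

# Previously this also required "not VLLM_USE_V1"; that restriction is gone.
attn_fp8_out = (rocm_fp8_out_flag and platform_is_fp8_fnuz
                and model_uses_fp8_quant)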
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/triton_attn.py
@@ -175,6 +175,8 @@ def forward(
             v_scale=layer._v_scale,
             alibi_slopes=self.alibi_slopes,
             sliding_window=self.sliding_window[0],
-            sm_scale=self.scale)
+            sm_scale=self.scale,
+            fp8_out_scale=fp8_out_scale,
+        )
 
         return output
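Taken together, no new user-facing knobs appear in this diff: with the existing ROCm flag enabled and an fp8-quantized model, the V1 Triton backend now allocates the attention output in fp8 and threads the output scale down to the kernels shown above. A hypothetical invocation (the model id is a placeholder for any fp8 checkpoint, and the flag should be set before vLLM reads its environment):

import os
os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="org/llama-3.1-8b-instruct-fp8")  # placeholder fp8 checkpoint
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)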