Commit 732455b

Fused FP8 conversion in attention for v1 (#502)
* Enable fused fp8 out in V1 CPA and FA
* Correct operation and create the tensor of the correct type
* Update to use for the non-custom path as well
* This was a debug assert
1 parent 6d258fa commit 732455b

5 files changed: +60, −43 lines
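For context, the change fuses the FP8 output conversion into the attention epilogue: when an `fp8_out_scale` is provided, the output tensor is allocated in the platform FP8 dtype and the kernels divide the accumulator by that scale right before the final store, so no separate quantization pass over the output is needed. A rough conceptual sketch in plain PyTorch follows; `FP8_DTYPE`, `unfused_epilogue`, and `fused_epilogue` are illustrative names, and `torch.float8_e4m3fnuz` is only an assumed stand-in for the platform FP8 dtype.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fnuz  # assumption: ROCm-style fnuz FP8


def unfused_epilogue(acc: torch.Tensor, out_scale: torch.Tensor) -> torch.Tensor:
    # Without fusion: attention stores a bf16/fp16 output, and a separate
    # elementwise pass rescales and converts it to FP8 afterwards.
    out = acc.to(torch.bfloat16)
    return (out / out_scale).to(FP8_DTYPE)


def fused_epilogue(acc: torch.Tensor, out_scale: torch.Tensor) -> torch.Tensor:
    # With this commit: the kernel divides its fp32 accumulator by the
    # output scale in the epilogue and stores straight into an FP8 buffer.
    return (acc / out_scale).to(FP8_DTYPE)
```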

vllm/attention/layer.py

Lines changed: 3 additions & 1 deletion
@@ -186,8 +186,10 @@ def forward(
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
+            output_dtype = (query.dtype if fp8_out_scale is None else
+                            current_platform.fp8_dtype())
             output = torch.empty(output_shape,
-                                 dtype=query.dtype,
+                                 dtype=output_dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
             # We skip reshaping query, key and value tensors for the MLA
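A minimal standalone sketch of the allocation logic in this hunk, assuming `current_platform.fp8_dtype()` resolves to a torch float8 dtype; the helper name `allocate_attn_output` is illustrative, not vLLM API.

```python
import torch


def allocate_attn_output(query: torch.Tensor, output_shape,
                         fp8_out_scale) -> torch.Tensor:
    # Mirrors the hunk above: keep the query dtype unless an FP8 output
    # scale is supplied, in which case the buffer is created directly in
    # FP8 so the kernel can write quantized values into it.
    fp8_dtype = torch.float8_e4m3fnuz  # stand-in for current_platform.fp8_dtype()
    output_dtype = query.dtype if fp8_out_scale is None else fp8_dtype
    return torch.empty(output_shape, dtype=output_dtype, device=query.device)
```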

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 45 additions & 38 deletions
@@ -23,41 +23,43 @@ def cdiv_fn(x, y):
 
 @triton.jit
 def kernel_paged_attention_2d(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    query_ptr,  # [num_tokens, num_query_heads, head_size]
-    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
-    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
-    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-    seq_lens_ptr,  # [num_seqs]
-    alibi_slopes_ptr,  # [num_query_heads]
-    scale,  # float32
-    k_scale,  # float32
-    v_scale,  # float32
-    num_query_heads: tl.constexpr,  # int
-    num_queries_per_kv: tl.constexpr,  # int
-    num_queries_per_kv_padded: tl.constexpr,  # int
-    block_table_stride: tl.int64,  # int
-    query_stride_0: tl.int64,  # int
-    query_stride_1: tl.int64,  # int, should be equal to head_size
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    USE_ALIBI_SLOPES: tl.constexpr,  # bool
-    SLIDING_WINDOW: tl.constexpr,  # int
-    x: tl.constexpr,  # int
-    stride_k_cache_0: tl.int64,  # int
-    stride_k_cache_1: tl.int64,  # int
-    stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.int64,  # int
-    stride_k_cache_4: tl.int64,  # int
-    stride_v_cache_0: tl.int64,  # int
-    stride_v_cache_1: tl.int64,  # int
-    stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.int64,  # int
-    filter_by_query_len: tl.constexpr,  # bool
-    query_start_len_ptr,  # [num_seqs+1]
+        output_ptr,  # [num_tokens, num_query_heads, head_size]
+        query_ptr,  # [num_tokens, num_query_heads, head_size]
+        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
+        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+        seq_lens_ptr,  # [num_seqs]
+        alibi_slopes_ptr,  # [num_query_heads]
+        scale,  # float32
+        k_scale,  # float32
+        v_scale,  # float32
+        out_scale,
+        num_query_heads: tl.constexpr,  # int
+        num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
+        block_table_stride: tl.int64,  # int
+        query_stride_0: tl.int64,  # int
+        query_stride_1: tl.int64,  # int, should be equal to head_size
+        output_stride_0: tl.int64,  # int
+        output_stride_1: tl.int64,  # int, should be equal to head_size
+        BLOCK_SIZE: tl.constexpr,  # int
+        HEAD_SIZE: tl.constexpr,  # int
+        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+        USE_ALIBI_SLOPES: tl.constexpr,  # bool
+        SLIDING_WINDOW: tl.constexpr,  # int
+        x: tl.constexpr,  # int
+        stride_k_cache_0: tl.int64,  # int
+        stride_k_cache_1: tl.int64,  # int
+        stride_k_cache_2: tl.int64,  # int
+        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_4: tl.int64,  # int
+        stride_v_cache_0: tl.int64,  # int
+        stride_v_cache_1: tl.int64,  # int
+        stride_v_cache_2: tl.int64,  # int
+        stride_v_cache_3: tl.int64,  # int
+        filter_by_query_len: tl.constexpr,  # bool
+        query_start_len_ptr,  # [num_seqs+1]
+        USE_FP8: tl.constexpr,
 ):
     seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -192,6 +194,8 @@ def kernel_paged_attention_2d(
 
     # epilogue
     acc = acc / L[:, None]
+    if USE_FP8:
+        acc = acc / tl.load(out_scale)
 
     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)
@@ -222,8 +226,8 @@ def chunked_prefill_paged_decode(
     alibi_slopes=None,
     sliding_window=None,
     sm_scale=None,
+    fp8_out_scale=None,
 ):
-
     if sm_scale is None:
         sm_scale = 1.0 / (query.shape[1]**0.5)
 
@@ -252,6 +256,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            fp8_out_scale=fp8_out_scale,
         )
 
     block_size = value_cache.shape[3]
@@ -293,7 +298,7 @@ def chunked_prefill_paged_decode(
         tmp_output = torch.empty(
            size=(total_num_seq, num_query_heads, max_num_partitions,
                  head_size),
-            dtype=output.dtype,
+            dtype=query.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
@@ -322,7 +327,7 @@ def chunked_prefill_paged_decode(
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
-            fp8_out_scale=None,
+            fp8_out_scale=fp8_out_scale,
        )
    else:
        kernel_paged_attention_2d[(
@@ -339,6 +344,7 @@ def chunked_prefill_paged_decode(
            scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
+            out_scale=fp8_out_scale,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
            num_queries_per_kv_padded=num_queries_per_kv_padded,
@@ -364,4 +370,5 @@ def chunked_prefill_paged_decode(
            stride_v_cache_3=value_cache.stride(3),
            filter_by_query_len=True,
            query_start_len_ptr=query_start_loc,
+            USE_FP8=fp8_out_scale is not None,
        )
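Host-side, `chunked_prefill_paged_decode` simply threads `fp8_out_scale` through to `context_attention_fwd`, the ROCm custom paged-attention op, or the 2D Triton kernel, and derives the `USE_FP8` constexpr from its presence. The kernel-side epilogue can be checked against a small PyTorch reference; this is a hedged sketch, not the Triton code itself, and `epilogue_reference` is an illustrative name.

```python
from typing import Optional

import torch


def epilogue_reference(acc: torch.Tensor, L: torch.Tensor,
                       out_scale: Optional[torch.Tensor]) -> torch.Tensor:
    # Reference for the kernel epilogue above: normalize the accumulator by
    # the softmax denominator L, then, when an output scale is present
    # (USE_FP8 in the kernel), divide by the 1-element scale before the
    # values are written to the FP8-typed output tensor.
    acc = acc / L[:, None]
    if out_scale is not None:
        acc = acc / out_scale
    return acc
```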

vllm/attention/ops/prefix_prefill.py

Lines changed: 8 additions & 1 deletion
@@ -29,6 +29,7 @@ def _fwd_kernel(
     sm_scale,
     k_scale,
     v_scale,
+    out_scale,
     B_Start_Loc,
     B_Seqlen,
     block_size,
@@ -65,6 +66,7 @@ def _fwd_kernel(
     BLOCK_N: tl.constexpr,
     SLIDING_WINDOW: tl.constexpr,
     SKIP_DECODE: tl.constexpr,
+    USE_FP8: tl.constexpr,
 ):
 
     cur_batch = tl.program_id(0)
@@ -263,6 +265,8 @@ def _fwd_kernel(
              (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
              cur_head * stride_oh + offs_d[None, :] * stride_od)
     out_ptrs = Out + off_o
+    if USE_FP8:
+        acc = acc / tl.load(out_scale)
     tl.store(out_ptrs,
              acc,
              mask=dim_mask[None, :] &
@@ -732,7 +736,8 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None,
                           sm_scale=None,
-                          skip_decode=False):
+                          skip_decode=False,
+                          fp8_out_scale=None):
 
     q_dtype_is_f32 = q.dtype is torch.float32
     # need to reduce num. blocks when using fp32
@@ -852,6 +857,7 @@ def context_attention_fwd(q,
         sm_scale,
         k_scale,
         v_scale,
+        fp8_out_scale,
         b_start_loc,
         b_seq_len,
         v_cache.shape[3],
@@ -890,6 +896,7 @@ def context_attention_fwd(q,
         BLOCK_N=BLOCK,
         SLIDING_WINDOW=sliding_window,
         SKIP_DECODE=skip_decode,
+        USE_FP8=fp8_out_scale is not None,
         num_warps=NUM_WARPS,
         num_stages=1,
     )
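Because `USE_FP8` is a `tl.constexpr`, Triton specializes the kernel at compile time and the extra divide disappears entirely when no scale is passed. A self-contained toy kernel showing that gating pattern is sketched below; everything here (`_scaled_copy_kernel`, `scaled_copy`, the block size) is illustrative, not vLLM code.

```python
from typing import Optional

import torch
import triton
import triton.language as tl


@triton.jit
def _scaled_copy_kernel(x_ptr, out_ptr, out_scale_ptr, n,
                        USE_FP8: tl.constexpr, BLOCK: tl.constexpr):
    # Same gating pattern as the _fwd_kernel epilogue above: the branch is
    # resolved at compile time, so the non-FP8 specialization carries no
    # extra work in its store path.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    acc = tl.load(x_ptr + offs, mask=mask)
    if USE_FP8:
        acc = acc / tl.load(out_scale_ptr)  # 1-element output scale
    tl.store(out_ptr + offs, acc, mask=mask)


def scaled_copy(x: torch.Tensor,
                out_scale: Optional[torch.Tensor]) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    # When out_scale is None the pointer is never read (USE_FP8 is False),
    # so any tensor can stand in for it.
    dummy = out_scale if out_scale is not None else x
    grid = (triton.cdiv(n, 128), )
    _scaled_copy_kernel[grid](x, out, dummy, n,
                              USE_FP8=out_scale is not None, BLOCK=128)
    return out
```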

vllm/model_executor/models/llama.py

Lines changed: 1 addition & 2 deletions
@@ -205,8 +205,7 @@ def __init__(self,
         use_fp8 = isinstance(
             quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
                                          and quant_config.is_fp8_w8a8())
-        self.attn_fp8_out = (not envs.VLLM_USE_V1
-                             and envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT
+        self.attn_fp8_out = (envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT
                              and current_platform.is_fp8_fnuz() and use_fp8)
 
         self.attn = Attention(
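The model-side change only drops the V0-only guard. A hedged restatement of the resulting condition, with illustrative helper and argument names:

```python
def use_fp8_attn_output(rocm_fp8_out_env: bool, is_fp8_fnuz_platform: bool,
                        model_uses_fp8: bool) -> bool:
    # Restates the updated expression: with `not envs.VLLM_USE_V1` removed,
    # V1 engines also get FP8 attention output whenever the ROCm env flag,
    # an fnuz-FP8 platform, and FP8 quantization are all present.
    return rocm_fp8_out_env and is_fp8_fnuz_platform and model_uses_fp8
```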

vllm/v1/attention/backends/triton_attn.py

Lines changed: 3 additions & 1 deletion
@@ -175,6 +175,8 @@ def forward(
             v_scale=layer._v_scale,
             alibi_slopes=self.alibi_slopes,
             sliding_window=self.sliding_window[0],
-            sm_scale=self.scale)
+            sm_scale=self.scale,
+            fp8_out_scale=fp8_out_scale,
+        )
 
         return output
