@@ -23,41 +23,43 @@ def cdiv_fn(x, y):
 
 @triton.jit
 def kernel_paged_attention_2d(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    query_ptr,  # [num_tokens, num_query_heads, head_size]
-    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
-    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
-    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-    seq_lens_ptr,  # [num_seqs]
-    alibi_slopes_ptr,  # [num_query_heads]
-    scale,  # float32
-    k_scale,  # float32
-    v_scale,  # float32
-    num_query_heads: tl.constexpr,  # int
-    num_queries_per_kv: tl.constexpr,  # int
-    num_queries_per_kv_padded: tl.constexpr,  # int
-    block_table_stride: tl.int64,  # int
-    query_stride_0: tl.int64,  # int
-    query_stride_1: tl.int64,  # int, should be equal to head_size
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    USE_ALIBI_SLOPES: tl.constexpr,  # bool
-    SLIDING_WINDOW: tl.constexpr,  # int
-    x: tl.constexpr,  # int
-    stride_k_cache_0: tl.int64,  # int
-    stride_k_cache_1: tl.int64,  # int
-    stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.int64,  # int
-    stride_k_cache_4: tl.int64,  # int
-    stride_v_cache_0: tl.int64,  # int
-    stride_v_cache_1: tl.int64,  # int
-    stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.int64,  # int
-    filter_by_query_len: tl.constexpr,  # bool
-    query_start_len_ptr,  # [num_seqs+1]
+    output_ptr,  # [num_tokens, num_query_heads, head_size]
+    query_ptr,  # [num_tokens, num_query_heads, head_size]
+    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
+    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+    seq_lens_ptr,  # [num_seqs]
+    alibi_slopes_ptr,  # [num_query_heads]
+    scale,  # float32
+    k_scale,  # float32
+    v_scale,  # float32
+    out_scale,
+    num_query_heads: tl.constexpr,  # int
+    num_queries_per_kv: tl.constexpr,  # int
+    num_queries_per_kv_padded: tl.constexpr,  # int
+    block_table_stride: tl.int64,  # int
+    query_stride_0: tl.int64,  # int
+    query_stride_1: tl.int64,  # int, should be equal to head_size
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int, should be equal to head_size
+    BLOCK_SIZE: tl.constexpr,  # int
+    HEAD_SIZE: tl.constexpr,  # int
+    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    SLIDING_WINDOW: tl.constexpr,  # int
+    x: tl.constexpr,  # int
+    stride_k_cache_0: tl.int64,  # int
+    stride_k_cache_1: tl.int64,  # int
+    stride_k_cache_2: tl.int64,  # int
+    stride_k_cache_3: tl.int64,  # int
+    stride_k_cache_4: tl.int64,  # int
+    stride_v_cache_0: tl.int64,  # int
+    stride_v_cache_1: tl.int64,  # int
+    stride_v_cache_2: tl.int64,  # int
+    stride_v_cache_3: tl.int64,  # int
+    filter_by_query_len: tl.constexpr,  # bool
+    query_start_len_ptr,  # [num_seqs+1]
+    USE_FP8: tl.constexpr,
 ):
     seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
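For context on the new `USE_FP8: tl.constexpr` flag: Triton treats `constexpr` arguments as compile-time constants and specializes the kernel per value, so the FP8 branch added in the epilogue below is dead-code-eliminated when the feature is off. A minimal, self-contained sketch (not the vLLM kernel; names are illustrative):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, scale_ptr, n,
                 USE_FP8: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if USE_FP8:
        # Branch is removed at compile time when USE_FP8=False;
        # scale_ptr points at a 1-element tensor, loaded as a scalar.
        x = x / tl.load(scale_ptr)
    tl.store(out_ptr + offs, x, mask=mask)


x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
scale = torch.tensor([2.0], device="cuda")
grid = (triton.cdiv(x.numel(), 256),)
scale_kernel[grid](x, out, scale, x.numel(), USE_FP8=True, BLOCK=256)
assert torch.allclose(out, x / 2.0)
```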
@@ -192,6 +194,8 @@ def kernel_paged_attention_2d(
 
     # epilogue
     acc = acc / L[:, None]
+    if USE_FP8:
+        acc = acc / tl.load(out_scale)
 
     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)
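The epilogue change divides the normalized accumulator by `out_scale` (a 1-element tensor, loaded as a scalar) so that, when the output buffer is FP8, values land inside FP8's narrow representable range; a consumer multiplies by the same scale to dequantize. A hedged sketch of that math in plain PyTorch (the scale-selection heuristic here is an assumption, not taken from this change):

```python
import torch

acc = torch.randn(4, 128) * 30.0               # attention output in float32
out_scale = acc.abs().max() / 448.0            # 448 = max finite float8_e4m3fn value
q = (acc / out_scale).to(torch.float8_e4m3fn)  # what the kernel would store
deq = q.to(torch.float32) * out_scale          # consumer-side dequantization
print((acc - deq).abs().max())                 # small quantization error
```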
@@ -222,8 +226,8 @@ def chunked_prefill_paged_decode(
     alibi_slopes=None,
     sliding_window=None,
     sm_scale=None,
+    fp8_out_scale=None,
 ):
-
     if sm_scale is None:
         sm_scale = 1.0 / (query.shape[1]**0.5)
 
@@ -252,6 +256,7 @@ def chunked_prefill_paged_decode(
         sliding_window=sliding_window,
         sm_scale=sm_scale,
         skip_decode=True,
+        fp8_out_scale=fp8_out_scale,
     )
 
     block_size = value_cache.shape[3]
@@ -293,7 +298,7 @@ def chunked_prefill_paged_decode(
         tmp_output = torch.empty(
             size=(total_num_seq, num_query_heads, max_num_partitions,
                   head_size),
-            dtype=output.dtype,
+            dtype=query.dtype,
             device=output.device,
         )
         exp_sums = torch.empty(
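Note the `tmp_output` dtype switch from `output.dtype` to `query.dtype`: presumably, when the final output buffer is FP8, the per-partition temporaries should stay in the wider query dtype so rounding error does not compound before the partitions are reduced. A small, illustrative demonstration of the gap:

```python
import torch

parts = torch.randn(64, 128)   # stand-in for per-partition partial results
ref = parts.double().sum(0)

# Round-tripping temporaries through FP8 loses far more precision than
# keeping them in a 16-bit query dtype.
err_fp8 = (parts.to(torch.float8_e4m3fn).to(torch.float64).sum(0) - ref).abs().max()
err_f16 = (parts.to(torch.float16).to(torch.float64).sum(0) - ref).abs().max()
print(f"fp8 temporaries : {err_fp8.item():.4f}")
print(f"fp16 temporaries: {err_f16.item():.4f}")
```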
@@ -322,7 +327,7 @@ def chunked_prefill_paged_decode(
             kv_cache_dtype=kv_cache_dtype,
             k_scale=k_scale,
             v_scale=v_scale,
-            fp8_out_scale=None,
+            fp8_out_scale=fp8_out_scale,
         )
     else:
         kernel_paged_attention_2d[(
@@ -339,6 +344,7 @@ def chunked_prefill_paged_decode(
             scale=sm_scale,
             k_scale=k_scale,
             v_scale=v_scale,
+            out_scale=fp8_out_scale,
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
             num_queries_per_kv_padded=num_queries_per_kv_padded,
@@ -364,4 +370,5 @@ def chunked_prefill_paged_decode(
             stride_v_cache_3=value_cache.stride(3),
             filter_by_query_len=True,
             query_start_len_ptr=query_start_loc,
+            USE_FP8=fp8_out_scale is not None,
         )
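Callers opt in by passing `fp8_out_scale` as a 1-element float32 tensor on the kernel's device, or `None` to keep the old path; the launch above derives `USE_FP8 = fp8_out_scale is not None`. A hypothetical helper, not part of this change, sketching one way such a scale could be derived:

```python
import torch

def make_fp8_out_scale(sample_out: torch.Tensor, enable_fp8: bool):
    """Return a 1-element float32 scale tensor, or None to disable FP8 output."""
    if not enable_fp8:
        return None
    # 448.0 is the largest finite value representable in float8_e4m3fn.
    amax = sample_out.abs().amax().clamp(min=1e-12)
    return (amax / 448.0).reshape(1).to(torch.float32)

scale = make_fp8_out_scale(torch.randn(8, 64) * 10, enable_fp8=True)
print(scale)  # would be passed as fp8_out_scale=...; None keeps USE_FP8=False
```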