@@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                        // max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
@@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  const float out_scale =
+      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
-
+  acc *= out_scale;
   const int64_t query_start_off = static_cast<int64_t>(
       query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
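
For reference, a minimal standalone sketch of the epilogue pattern this hunk introduces, using only the names visible in the diff (the helper function itself is hypothetical, not part of the kernel): the accumulator is normalized by the global exp-sum and then multiplied by the reciprocal of the optional fp8 output scale, so a null pointer leaves the result unchanged.

// Hypothetical device helper mirroring the scaling above; a null
// fp8_out_scale_ptr degenerates to a multiply by 1.0f.
__device__ inline float finalize_acc(float acc, float global_exp_sum,
                                      const float* fp8_out_scale_ptr) {
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
  const float out_scale =
      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
  return acc * inv_global_exp_sum * out_scale;
}
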
@@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 // clang-format on
@@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
           PARTITION_SIZE, NPAR_LOOPS> \
       <<<reduce_grid, reduce_block, 0, stream>>>( \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
-          context_lens_ptr, query_start_loc_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
+          fp8_out_scale_ptr);

 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1594,7 +1597,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
   int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+  // NOTE: fp8_out_scale is optional.
+  const auto fp8_out_scale_ptr =
+      fp8_out_scale
+          ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
+          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());

   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
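
As a self-contained illustration of the host-side pattern used in this hunk (the helper name and includes are assumptions, not part of this file): the optional scale tensor is lowered to a nullable raw pointer so the kernel can simply branch on nullptr.

#include <optional>
#include <torch/torch.h>

// Hypothetical helper: map an optional float scale tensor to a nullable
// pointer; returns nullptr when no fp8 output scale is supplied.
static const float* optional_scale_ptr(const std::optional<torch::Tensor>& t) {
  return t.has_value() ? static_cast<const float*>(t->data_ptr()) : nullptr;
}
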
@@ -1736,33 +1744,54 @@ void paged_attention_custom_launcher(
   }
 }

-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \
-                             ALIBI_ENABLED) \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
-                                  PSIZE, ALIBI_ENABLED>( \
-      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
-      num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
-      max_context_len, alibi_slopes, k_scale, v_scale);
-
-#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
-                                   PSIZE) \
-  if (alibi_slopes) { \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \
-  } else { \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                             PSIZE, ALIBI_ENABLED) \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                                  PSIZE, ALIBI_ENABLED>( \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
+      max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
+
+#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                   OUTT, PSIZE) \
+  if (alibi_slopes) { \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
+                         true); \
+  } else { \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
+                         false); \
   }

-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
-  switch (block_size) { \
-    case 16: \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \
-      break; \
-    case 32: \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \
-      break; \
-    default: \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
-      break; \
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+    if (fp8_out_scale) { \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
+    } else { \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
+                                 256); \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+    if (fp8_out_scale) { \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                 uint8_t, 256); \
+    } else { \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
+                                 256); \
+    }
+#endif
+
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
+  switch (block_size) { \
+    case 16: \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
+      break; \
+    case 32: \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
+      break; \
+    default: \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break; \
   }

 #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
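
For readability, a hedged restatement of how the new output-type dispatch resolves; the expansion below is illustrative only and elides the ALiBi branch.

// On a non-gfx90a target with block_size == 16 and fp8_out_scale present,
// CALL_CUSTOM_LAUNCHER_BLK expands, roughly, to
//   paged_attention_custom_launcher<T, KVT, KV_DTYPE, /*BLK_SIZE=*/16,
//                                   HEAD_SIZE, /*OUTT=*/uint8_t,
//                                   /*PSIZE=*/256, ALIBI_ENABLED>(...);
// i.e. the output element type becomes uint8_t (fp8 storage) only when a
// scale is provided; on gfx90a that path is rejected via TORCH_CHECK.
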
@@ -1795,7 +1824,8 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale) {
   // clang-format on
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {