1717 * limitations under the License. 
1818 */  
1919
20- #include  < torch/extension .h> 
20+ #include  < torch/all .h> 
2121#include  < ATen/cuda/CUDAContext.h> 
2222#include  < c10/cuda/CUDAGuard.h> 
2323#include  < algorithm> 
@@ -808,16 +808,17 @@ void paged_attention_v1(
808808    torch::Tensor&
809809        key_cache,  //  [num_blocks, num_heads, head_size/x, block_size, x]
810810    torch::Tensor&
811-         value_cache,   //  [num_blocks, num_heads, head_size, block_size]
812-     int  num_kv_heads,  //  [num_heads]
813-     float  scale,
811+         value_cache,        //  [num_blocks, num_heads, head_size, block_size]
812+     int64_t  num_kv_heads,  //  [num_heads]
813+     double  scale,
814814    torch::Tensor& block_tables,  //  [num_seqs, max_num_blocks_per_seq]
815815    torch::Tensor& seq_lens,      //  [num_seqs]
816-     int  block_size, int  max_seq_len,
816+     int64_t  block_size, int64_t  max_seq_len,
817817    const  c10::optional<torch::Tensor>& alibi_slopes,
818-     const  std::string& kv_cache_dtype, float  kv_scale, const  int  tp_rank,
819-     const  int  blocksparse_local_blocks, const  int  blocksparse_vert_stride,
820-     const  int  blocksparse_block_size, const  int  blocksparse_head_sliding_step) {
818+     const  std::string& kv_cache_dtype, double  kv_scale, const  int64_t  tp_rank,
819+     const  int64_t  blocksparse_local_blocks,
820+     const  int64_t  blocksparse_vert_stride, const  int64_t  blocksparse_block_size,
821+     const  int64_t  blocksparse_head_sliding_step) {
821822  const  bool  is_block_sparse = (blocksparse_vert_stride > 1 );
822823
823824  DISPATCH_BY_KV_CACHE_DTYPE (query.dtype (), kv_cache_dtype,
@@ -972,16 +973,17 @@ void paged_attention_v2(
972973    torch::Tensor&
973974        key_cache,  //  [num_blocks, num_heads, head_size/x, block_size, x]
974975    torch::Tensor&
975-         value_cache,   //  [num_blocks, num_heads, head_size, block_size]
976-     int  num_kv_heads,  //  [num_heads]
977-     float  scale,
976+         value_cache,        //  [num_blocks, num_heads, head_size, block_size]
977+     int64_t  num_kv_heads,  //  [num_heads]
978+     double  scale,
978979    torch::Tensor& block_tables,  //  [num_seqs, max_num_blocks_per_seq]
979980    torch::Tensor& seq_lens,      //  [num_seqs]
980-     int  block_size, int  max_seq_len,
981+     int64_t  block_size, int64_t  max_seq_len,
981982    const  c10::optional<torch::Tensor>& alibi_slopes,
982-     const  std::string& kv_cache_dtype, float  kv_scale, const  int  tp_rank,
983-     const  int  blocksparse_local_blocks, const  int  blocksparse_vert_stride,
984-     const  int  blocksparse_block_size, const  int  blocksparse_head_sliding_step) {
983+     const  std::string& kv_cache_dtype, double  kv_scale, const  int64_t  tp_rank,
984+     const  int64_t  blocksparse_local_blocks,
985+     const  int64_t  blocksparse_vert_stride, const  int64_t  blocksparse_block_size,
986+     const  int64_t  blocksparse_head_sliding_step) {
985987  const  bool  is_block_sparse = (blocksparse_vert_stride > 1 );
986988  DISPATCH_BY_KV_CACHE_DTYPE (query.dtype (), kv_cache_dtype,
987989                             CALL_V2_LAUNCHER_BLOCK_SIZE)
@@ -990,4 +992,4 @@ void paged_attention_v2(
990992#undef  WARP_SIZE
991993#undef  MAX
992994#undef  MIN
993- #undef  DIVIDE_ROUND_UP
995+ #undef  DIVIDE_ROUND_UP
0 commit comments