 from vllm.v1.worker.gpu_input_batch import InputBatch

 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)


 class AscendAttentionBackend(AttentionBackend):
@@ -62,6 +64,9 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
+        if is_310p():
+            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
+                    16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
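For context: on 310P the per-block KV plane is laid out fractal-NZ style, with the combined num_kv_heads * head_size axis split into 16-element fragments that become the innermost dimension. A minimal sketch of how an ND-layout cache block maps onto that shape (sizes are hypothetical, and a plain reshape/permute stands in for the real npu_format_cast):

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 128, 8, 128
    nd = torch.randn(num_blocks, block_size, num_kv_heads * head_size)

    # split the hidden axis into 16-wide fragments and move them outward:
    # (blocks, block_size, h) -> (blocks, h // 16, block_size, 16)
    nz_like = nd.reshape(num_blocks, block_size, -1, 16).permute(0, 2, 1, 3)
    assert nz_like.shape == (num_blocks, num_kv_heads * head_size // 16,
                             block_size, 16)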
@@ -166,6 +171,16 @@ def build(self,
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)

+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
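nd_to_nz_2d and nd_to_nz_spec come from vllm_ascend.utils and are not shown in this diff; the exact layout they produce is not reproduced here. As a rough, hypothetical illustration of the general ND-to-NZ idea only, a 2D mask is padded so both axes are multiples of 16 and 16-wide column fragments become the innermost dimension before npu_format_cast tags the tensor as fractal NZ:

    import torch
    import torch.nn.functional as F

    def nd_to_nz_2d_sketch(mask: torch.Tensor) -> torch.Tensor:
        """Illustrative stand-in only; the real helper lives in vllm_ascend.utils."""
        m, n = mask.shape
        m_pad = (m + 15) // 16 * 16
        n_pad = (n + 15) // 16 * 16
        padded = F.pad(mask, (0, n_pad - n, 0, m_pad - m))
        # (M, N) -> (1, N/16, M_pad, 16): column fragments become the innermost axis
        return padded.reshape(m_pad, n_pad // 16, 16).permute(1, 0, 2).unsqueeze(0)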
@@ -249,6 +264,7 @@ def forward(
                              self.head_size,
                              dtype=query.dtype,
                              device=query.device)
+        ori_output = output
         if trace_flag:
             torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
@@ -293,6 +309,18 @@ def forward(
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
+            if is_310p():
+                # align q k v output tensors
+                query = aligned_16(query)
+                key = aligned_16(key)
+                value = aligned_16(value)
+                output = aligned_16(output)
+
+                # do reformat in case of broadcasted tensors
+                mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                 ACL_FORMAT_FRACTAL_NZ)
+
             torch_npu._npu_flash_attention(query=query,
                                            key=key,
                                            value=value,
@@ -302,6 +330,7 @@ def forward(
                                            num_heads=self.num_heads,
                                            num_kv_heads=self.num_kv_heads,
                                            out=output)
+            output = output[:num_tokens, :, :]
         elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
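aligned_16 is imported from vllm_ascend.utils; judging by its name and by the output[:num_tokens, :, :] slice above, it pads the token dimension up to the next multiple of 16 so the 310P kernels see 16-aligned shapes. A plausible stand-in, not the repository's implementation:

    import torch

    def aligned_16_sketch(x: torch.Tensor) -> torch.Tensor:
        """Pad dim 0 to the next multiple of 16 (illustrative stand-in)."""
        n = x.size(0)
        n_pad = -(-n // 16) * 16  # ceil to a multiple of 16
        if n_pad == n:
            return x
        pad = torch.zeros(n_pad - n, *x.shape[1:], dtype=x.dtype, device=x.device)
        return torch.cat((x, pad), dim=0)

Since the kernel then writes into the padded buffer, only the first num_tokens rows are kept afterwards.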
@@ -319,6 +348,10 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if is_310p():
+                # seq_lens_tensor needs to be transferred to the device for 310P
+                attn_metadata.seq_lens = \
+                    attn_metadata.seq_lens.to(device=query.device)
             torch_npu._npu_paged_attention(
                 query=query,
                 key_cache=self.key_cache,
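_npu_paged_attention reads keys and values through block_tables and the sequence lengths rather than from contiguous tensors. A toy illustration (hypothetical sizes, plain indexing in place of the fused kernel) of how a block table maps logical token positions onto cache slots:

    import torch

    num_blocks, block_size, kv_dim = 8, 4, 16
    key_cache = torch.randn(num_blocks, block_size, kv_dim)
    block_table = torch.tensor([3, 0, 5])   # cache blocks assigned to one sequence
    seq_len = 10                            # tokens actually stored

    pos = torch.arange(seq_len)
    keys = key_cache[block_table[pos // block_size], pos % block_size]
    assert keys.shape == (seq_len, kv_dim)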
@@ -352,6 +385,14 @@ def forward(
                                         self.scale, None, True)
             else:
                 # use paged attention
+                assert attn_metadata is not None
+                assert attn_metadata.attn_mask is not None
+                if is_310p():
+                    # do reformat in case of broadcasted tensors
+                    attn_metadata.attn_mask = \
+                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
                 torch_npu._npu_paged_attention_splitfuse(
                     query=query,
                     key_cache=self.key_cache,
@@ -364,6 +405,10 @@ def forward(
                     num_heads=self.num_heads,
                     scale_value=self.scale,
                     out=output)
+
+        # write the result back in place so the original output tensor sees it
+        if ori_output is not output:
+            ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)

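Putting the 310P pieces together: output may be swapped for a padded working copy, the kernel writes into that copy, and the valid rows are copied back so the tensor the caller handed in still receives the result. A condensed, self-contained sketch of that pattern with hypothetical shapes:

    import torch

    num_tokens, num_heads, head_size = 5, 8, 128
    output = torch.zeros(num_tokens, num_heads, head_size)

    ori_output = output                             # caller-visible tensor
    padded = torch.zeros(16, num_heads, head_size)  # stand-in for aligned_16(output)
    # ... the attention kernel would write its result into `padded` here ...
    result = padded[:num_tokens, :, :]              # drop the padding rows
    if ori_output is not result:
        ori_output[:, :, :] = result                # in-place copy back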