diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py
index 452f6489863..0a062c5aed9 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -838,8 +838,10 @@ def prepare_flash_mla(self) -> None:
         block_ids_per_seq = self.kv_cache_manager.get_block_ids_per_seq(
             self.request_ids).pin_memory()
         num_blocks = block_ids_per_seq.shape[1]
+        self.kv_block_ids_per_seq.fill_(0)
         self.kv_block_ids_per_seq[:self.num_seqs, :num_blocks].copy_(
             block_ids_per_seq, non_blocking=True)
+        self.block_ids_per_seq.fill_(0)
         self.block_ids_per_seq[:self.num_generations, :num_blocks].copy_(
             block_ids_per_seq[self.num_contexts:], non_blocking=True)
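
The added `fill_(0)` calls appear to guard against stale data: both buffers are preallocated and reused across iterations, and the subsequent `copy_` only writes the `[:num_seqs, :num_blocks]` (resp. `[:num_generations, :num_blocks]`) region, so a batch that is smaller than the previous one would otherwise leave old block IDs in the unwritten tail. Below is a minimal standalone sketch of that failure mode and the fix; the `prepare` helper and the `max_seqs` / `max_blocks` shapes are hypothetical stand-ins for the class attributes in the patch, not TRT-LLM API.

```python
import torch

# Hypothetical preallocated buffer reused across batches,
# analogous to self.kv_block_ids_per_seq in the patch.
max_seqs, max_blocks = 4, 8
kv_block_ids_per_seq = torch.zeros(max_seqs, max_blocks, dtype=torch.int32)


def prepare(block_ids_per_seq: torch.Tensor) -> None:
    num_seqs, num_blocks = block_ids_per_seq.shape
    # Mirrors the patch: zero the whole buffer first, so entries outside
    # [:num_seqs, :num_blocks] left over from an earlier, larger batch
    # cannot leak into the current step.
    kv_block_ids_per_seq.fill_(0)
    kv_block_ids_per_seq[:num_seqs, :num_blocks].copy_(block_ids_per_seq)


# Batch 1 writes 4 sequences x 6 blocks; batch 2 only 2 x 2.
prepare(torch.arange(1, 25, dtype=torch.int32).reshape(4, 6))
prepare(torch.arange(1, 5, dtype=torch.int32).reshape(2, 2))

# With fill_(0), rows 2-3 and columns 2-7 are zero after batch 2;
# without it, they would still hold batch 1's block IDs.
print(kv_block_ids_per_seq)
```

Zeroing the full buffer is the simple, robust choice here; a narrower alternative would be to clear only the tail regions beyond the current batch, but at these sizes the cost of a full `fill_(0)` should be negligible next to the copies.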