@@ -77,7 +77,8 @@ class RunnerBase
         torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
         torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
         torch::optional<torch::Tensor> mla_context_paged_kv,
-        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets) const
+        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
+        torch::optional<torch::Tensor> softmax_stats_tensor) const
         = 0;
 };
@@ -127,7 +128,8 @@ class Runner : public RunnerBase
         torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
         torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
         torch::optional<torch::Tensor> mla_context_paged_kv,
-        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets) const override
+        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
+        torch::optional<torch::Tensor> softmax_stats_tensor) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
         T* attention_input = static_cast<T*>(qkv.slice(0, token_offset).data_ptr());
@@ -279,6 +281,11 @@ class Runner : public RunnerBase
         AttentionOp::EnqueueContextParams<T> enqueue_params{common_enqueue_params};
         enqueue_params.host_block_offsets = host_block_offsets;
         enqueue_params.batch_size = num_seqs;
+        if (softmax_stats_tensor.has_value())
+        {
+            enqueue_params.softmaxStatsPtr = static_cast<float2*>(softmax_stats_tensor.value().data_ptr());
+        }
+
         if (op.isMLAEnabled())
         {
             mla_params.cache_seq_lens = sequence_lengths_ptr;
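
Note: the added branch above only forwards an optional, caller-provided statistics buffer to the context-phase kernels by reinterpreting its storage as float2*. Below is a minimal sketch of how such a buffer could be allocated on the caller side; the make_softmax_stats_buffer helper, the {num_tokens, num_heads, 2} shape, and the (max, sum) interpretation are illustrative assumptions, not part of this change.

    #include <torch/torch.h>

    // Hedged sketch: allocate a contiguous float32 buffer whose storage can be
    // reinterpreted as float2 pairs on the kernel side. Shape and layout are
    // assumptions for illustration; the real layout is defined by the kernels.
    torch::Tensor make_softmax_stats_buffer(int64_t num_tokens, int64_t num_heads, torch::Device device)
    {
        // Two float32 values per (token, head) row, so data_ptr() can be cast to float2*.
        return torch::empty({num_tokens, num_heads, 2},
            torch::TensorOptions().dtype(torch::kFloat32).device(device));
    }

Keeping the last dimension as two contiguous float32 values avoids needing a dedicated float2 dtype while staying compatible with the static_cast<float2*> done above.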
@@ -385,7 +392,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
     std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
     torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
     std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
-    std::optional<int64_t> attention_chunk_size)
+    std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -603,7 +610,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
             host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
             kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
             block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
-            mla_context_kv_cache_block_offsets);
+            mla_context_kv_cache_block_offsets, softmax_stats_tensor);
     }

     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -619,7 +626,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
             host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
             kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
             block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
-            mla_context_kv_cache_block_offsets);
+            mla_context_kv_cache_block_offsets, softmax_stats_tensor);
     }

     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
@@ -742,6 +749,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         ", Tensor? mla_context_paged_kv"
         ", Tensor? mla_context_kv_cache_block_offsets"
         ", int? attention_chunk_size"
+        ", Tensor? softmax_stats_tensor"
         ") -> ()");

     m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output);