Commit 4ffc1ff
DMMHA: add unit tests; fix CPU, CUDA kernel (microsoft#22567)
### Description

Fixes:
(1) CPU kernel: apply scale before the attention bias and mask, like the other MHA ops.
(2) CPU kernel: correct the offset used when appending past to present.
(3) CUDA kernel: apply the mask if provided; fix the output_qk offset.

Also adds DMMHA unit tests.
1 parent 2e4e221 commit 4ffc1ff
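
The ordering fix (1) is the subtle one. The snippet below is only a minimal illustration with hypothetical names, not the ONNX Runtime kernel itself: the point is that the raw Q·K dot product is scaled first, so the additive attention bias and the mask filter value are left unscaled, matching the other MHA ops.

```cpp
// Minimal illustration (hypothetical helper, not ORT code) of the fixed ordering:
// scale the raw dot product first, then add bias and mask terms unscaled.
float ScoreForKeyPosition(float qk_dot, float scale,
                          const float* attn_bias, int bias_offset,
                          bool is_masked, float mask_filter_value) {
  float score = qk_dot * scale;       // (1) scale the raw Q.K dot product
  if (attn_bias != nullptr) {
    score += attn_bias[bias_offset];  // additive bias, not scaled
  }
  if (is_masked) {
    score += mask_filter_value;       // large negative filler, not scaled
  }
  return score;
}
```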

7 files changed: 381 additions, 367 deletions

onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ class AttentionCPUBase : public AttentionBase {
     // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f).
     // Merge padding mask with causal mask, and broadcast to 3D (BxSxT).
     PrepareMask(mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
-                causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
+                causal, batch_size, sequence_length, kv_sequence_length, past_sequence_length, mask_filter_value_);
     DUMP_CPU_TENSOR("Mask3D", static_cast<T*>(mask_data), batch_size, sequence_length, total_sequence_length);
   }

onnxruntime/contrib_ops/cpu/bert/attention_helper.h

Lines changed: 2 additions & 1 deletion
@@ -120,9 +120,10 @@ void PrepareMask(const int32_t* mask_index,
                  bool causal,
                  int batch_size,
                  int sequence_length,
+                 int kv_sequence_length,
                  int past_sequence_length,
                  float mask_filter_value) {
-  const int all_sequence_length = past_sequence_length + sequence_length;
+  const int all_sequence_length = past_sequence_length + kv_sequence_length;
 
   // mask_data has been filled with 0, and its shape is BxSxT
   T* p_mask = mask_data;
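
The two changes above make the merged mask track the key/value stream: its last dimension is now past_sequence_length + kv_sequence_length rather than past_sequence_length + sequence_length. A rough standalone sketch of that BxSxT layout (hypothetical function, not the ORT helper; it only converts a 0/1 padding mask to mask_filter_value/0.0f as the comments describe):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the BxSxT float mask built from a 0/1 padding mask.
std::vector<float> BuildMergedMask(int batch_size, int sequence_length,
                                   int kv_sequence_length, int past_sequence_length,
                                   const std::vector<int32_t>& padding_mask,  // [batch, total] of 0/1
                                   float mask_filter_value) {
  const int total_sequence_length = past_sequence_length + kv_sequence_length;  // T (was past + sequence_length)
  std::vector<float> mask(static_cast<size_t>(batch_size) * sequence_length * total_sequence_length, 0.0f);
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < sequence_length; ++s) {
      for (int t = 0; t < total_sequence_length; ++t) {
        if (padding_mask[static_cast<size_t>(b) * total_sequence_length + t] == 0) {
          mask[(static_cast<size_t>(b) * sequence_length + s) * total_sequence_length + t] = mask_filter_value;
        }
      }
    }
  }
  return mask;
}
```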

onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc

Lines changed: 6 additions & 4 deletions
@@ -339,6 +339,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
         T* attention_probs_ptr = reinterpret_cast<T*>(attention_probs) + last_offset;
         math::Dot<float, CPUMathUtil>(head_size, q_vec, K + i * head_size, attention_probs_ptr, nullptr);
 
+        *attention_probs_ptr *= scale;
         // Apply the attention bias and mask
         if (attn_bias_data != nullptr) {
           *attention_probs_ptr += attn_bias_data[attn_bias_base_offset + past_sequence_length];
@@ -348,7 +349,6 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
         if (is_masked) {
           *attention_probs_ptr += mask_filter_value_;
         }
-        *attention_probs_ptr *= scale;
       }
 
       {
@@ -362,6 +362,8 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
           const T* past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size;
           T* output = reinterpret_cast<T*>(attention_probs) + j + i * probs_matrix_size;
           math::Dot<float, CPUMathUtil>(head_size, q_vec, past_k_vec, output, nullptr);
+
+          *output *= scale;
           // Apply the attention bias and mask
           if (attn_bias_data != nullptr) {
             *output += attn_bias_data[attn_bias_base_offset + j];
@@ -371,11 +373,11 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
           if (is_masked) {
             *output += mask_filter_value_;
           }
-          *output *= scale;
         }
       }
       // Append current key to present key (past_present_share_buffer_ is true)
-      memcpy(present_key_data + i * max_sequence_length * head_size, K + i * head_size, head_size * sizeof(T));
+      memcpy(present_key_data + (i * max_sequence_length + past_sequence_length) * head_size,
+             K + i * head_size, head_size * sizeof(T));
     }
   });
 
@@ -460,7 +462,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeVxAttentionScoreWithBeams(
         }
       }
       // Append current value to present value (past_present_share_buffer_ is true)
-      memcpy(present_value_data + i * max_sequence_length * v_head_size,
+      memcpy(present_value_data + (i * max_sequence_length + past_sequence_length) * v_head_size,
              V + i * v_head_size,
              v_head_size * sizeof(T));
     }
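
The two memcpy changes address fix (2): with past_present_share_buffer_ set, each row of the present K/V buffer holds max_sequence_length token slots, so the current token has to be written after the past_sequence_length entries that are already cached, not at slot 0. A small index sketch (hypothetical helper, assuming a [row, max_sequence_length, head_size] layout where the row index i is the kernel's loop index):

```cpp
#include <cstddef>

// Hypothetical sketch of the destination offset for appending the current token
// to the shared past/present buffer.
size_t PresentWriteOffset(size_t i, size_t max_sequence_length,
                          size_t past_sequence_length, size_t head_size) {
  // Before the fix: i * max_sequence_length * head_size, i.e. always slot 0 of row i,
  // which overwrites the first cached entry instead of appending.
  return (i * max_sequence_length + past_sequence_length) * head_size;
}
```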

onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC
                                const Tensor* cache_indir,
                                OpKernelContext* context,
                                int beam_width,
-                               Tensor* scaled_qk = nullptr) const;
+                               Tensor* output_qk = nullptr) const;
   void ComputeAttentionProbsWithBeams(T* attention_probs,
                                       const T* Q,
                                       const T* K,
@@ -50,7 +50,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC
                                       bool broadcast_attn_bias_dim_1,
                                       const int32_t* cache_indir_data,
                                       int beam_width,
-                                      T* scaled_qk_data = nullptr) const;
+                                      T* output_qk_data = nullptr) const;
   void ComputeVxAttentionScoreWithBeams(T* output,
                                         T* tmp_buffer,
                                         const T* attention_probs,

onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu

Lines changed: 4 additions & 1 deletion
@@ -298,6 +298,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
     if (params.attention_bias != nullptr) {
       qk = add_vec(qk, reinterpret_cast<T*>(params.attention_bias)[attn_bias_offset + tlength]);
     }
+    if (params.mask != nullptr && params.mask[bi_total_seq_length + params.past_sequence_length] == 0) {
+      qk += params.mask_filter_value;
+    }
     qk_max = qk;
     qk_smem[tlength] = qk;
   }
@@ -534,7 +537,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
 
   if (params.out_qk != nullptr) {
     // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length]
-    float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength);
+    float* target = (reinterpret_cast<float*>(params.out_qk)) + (static_cast<int64_t>(bhi) * (sum_tlength + 1));
     for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) {
       target[ti] = (float)(qk_smem[ti]);
     }
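
The out_qk change is purely an indexing fix: per the comment in the kernel, out_qk has shape [batch x beam, #Head, 1, total_sequence_length], so each (batch*beam, head) index bhi owns a row of total_sequence_length floats, which in this kernel is sum_tlength + 1 (the loop writes indices 0..sum_tlength). The old stride of tlength does not match that row length. Roughly, under those assumptions:

```cpp
#include <cstdint>

// Hypothetical sketch of the start of the out_qk row owned by one (batch*beam, head) pair.
// total_sequence_length corresponds to sum_tlength + 1 in the kernel above.
int64_t OutQkRowStart(int64_t bhi, int64_t total_sequence_length) {
  // Previously bhi * tlength, which does not equal the number of scores written per row.
  return bhi * total_sequence_length;
}
```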

onnxruntime/core/graph/contrib_ops/bert_defs.cc

Lines changed: 0 additions & 1 deletion
@@ -908,7 +908,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
               OpSchema::Optional)
        .Input(9,
               "cache_indirection",
-              // This input is useful for CUDA EP only.
               "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies "
               "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration",
               "M",
