@@ -3637,7 +3637,7 @@ struct test_flash_attn_ext : public test_case {
36373637
36383638 ggml_tensor * m = nullptr ;
36393639 if (mask) {
3640- m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[1 ], 1 );
3640+ m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[0 ], nr23[ 1 ] );
36413641 ggml_set_name (m, " m" );
36423642 }
36433643
@@ -4751,7 +4751,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
47514751 test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {1 , 1 }, scale, max_bias));
47524752
47534753 if (ne0 <= 32 && ne1 <= 32 ) {
4754- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, {3 , 1 }, scale, max_bias));
4754+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 3 }, mask, m_prec, {3 , 1 }, scale, max_bias));
47554755 test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {2 , 3 }, scale, max_bias));
47564756 }
47574757 }
0 commit comments