@@ -930,13 +930,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
930930    return  res;
931931}
932932
933+ ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_flash_attn_ext_pad (
934+         ggml_metal_library_t  lib,
935+         const  struct  ggml_tensor  * op,
936+         bool     has_mask,
937+         int32_t  ncpsg) {
938+     assert (op->op  == GGML_OP_FLASH_ATTN_EXT);
939+     GGML_UNUSED (op);
940+ 
941+     char  base[256 ];
942+     char  name[256 ];
943+ 
944+     snprintf (base, 256 , " kernel_%s"  ,
945+             " flash_attn_ext_pad"  );
946+ 
947+     snprintf (name, 256 , " %s_mask=%d_ncpsg=%d"  ,
948+             base,
949+             has_mask,
950+             ncpsg);
951+ 
952+     ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
953+     if  (res) {
954+         return  res;
955+     }
956+ 
957+     ggml_metal_cv_t  cv = ggml_metal_cv_init ();
958+ 
959+     ggml_metal_cv_set_bool (cv, has_mask,  FC_FLASH_ATTN_EXT_PAD + 0 );
960+   // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
961+   // ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
962+   // ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);
963+ 
964+   // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
965+   // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
966+   // ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);
967+   // ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);
968+     ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24 );
969+ 
970+     res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
971+ 
972+     ggml_metal_cv_free (cv);
973+ 
974+     return  res;
975+ }
976+ 
933977ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_flash_attn_ext (
934978        ggml_metal_library_t  lib,
935979        const  ggml_tensor * op,
936980        bool     has_mask,
937981        bool     has_sinks,
938982        bool     has_bias,
939983        bool     has_scap,
984+         bool     has_kvpad,
940985        int32_t  nsg) {
941986    assert (op->op  == GGML_OP_FLASH_ATTN_EXT);
942987
@@ -955,12 +1000,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9551000            dk,
9561001            dv);
9571002
958-     snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d"  ,
1003+     snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=% d_ns10=%d_ns20=%d_nsg=%d"  ,
9591004            base,
9601005            has_mask,
9611006            has_sinks,
9621007            has_bias,
9631008            has_scap,
1009+             has_kvpad,
9641010            ns10,
9651011            ns20,
9661012            nsg);
@@ -976,6 +1022,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9761022    ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
9771023    ggml_metal_cv_set_bool (cv, has_bias,  FC_FLASH_ATTN_EXT + 2 );
9781024    ggml_metal_cv_set_bool (cv, has_scap,  FC_FLASH_ATTN_EXT + 3 );
1025+     ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
9791026
9801027    ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
9811028    ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -995,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9951042        bool     has_sinks,
9961043        bool     has_bias,
9971044        bool     has_scap,
1045+         bool     has_kvpad,
9981046        int32_t  nsg,
9991047        int32_t  nwg) {
10001048    assert (op->op  == GGML_OP_FLASH_ATTN_EXT);
@@ -1014,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10141062            dk,
10151063            dv);
10161064
1017-     snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d"  ,
1065+     snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d"  ,
10181066            base,
10191067            has_mask,
10201068            has_sinks,
10211069            has_bias,
10221070            has_scap,
1071+             has_kvpad,
10231072            ns10,
10241073            ns20,
10251074            nsg, nwg);
@@ -1035,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10351084    ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
10361085    ggml_metal_cv_set_bool (cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2 );
10371086    ggml_metal_cv_set_bool (cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3 );
1087+     ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
10381088
10391089    ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
10401090    ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments