@@ -2106,10 +2106,16 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
21062106 const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
21072107 use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf (src0->type , cc, src0->ne , is_mul_mat_id ? src1->ne [2 ] : src1->ne [1 ]);
21082108
2109- if (tensor->op == GGML_OP_MUL_MAT_ID) {
2110- use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne [2 ] == 1 ;
2109+ // we only support fusion for ncols_dst = 1
2110+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne [1 ] != 1 ) {
2111+ return false ;
21112112 }
21122113
2114+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne [2 ] != 1 ) {
2115+ return false ;
2116+ }
2117+
2118+
21132119 return use_mul_mat_vec_f;
21142120}
21152121
@@ -2125,8 +2131,13 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
21252131 bool use_mul_mat_vec_q = ggml_is_quantized (src0->type ) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
21262132 dst->type == GGML_TYPE_F32 && src1->ne [1 ] <= MMVQ_MAX_BATCH_SIZE;
21272133
2128- if (tensor->op == GGML_OP_MUL_MAT_ID) {
2129- use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne [2 ] == 1 ;
2134+ // we only support fusion for ncols_dst = 1
2135+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne [1 ] != 1 ) {
2136+ return false ;
2137+ }
2138+
2139+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne [2 ] != 1 ) {
2140+ return false ;
21302141 }
21312142
21322143 return use_mul_mat_vec_q;
@@ -2979,12 +2990,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
29792990 }
29802991 }
29812992
2982- std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
2993+ std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
29832994 std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
29842995
29852996 std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
2986-
2987- std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
2997+ std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
29882998
29892999 if (ops.size () == 5 && (ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 4 }) ||
29903000 ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 4 }))) {
0 commit comments