@@ -463,6 +463,7 @@ struct vk_device_struct {
463463
464464 vk_pipeline pipeline_leaky_relu_f32;
465465 vk_pipeline pipeline_silu_back_f32;
466+ vk_pipeline pipeline_geglu_back_f32;
466467 vk_pipeline pipeline_diag_mask_inf_f32;
467468 vk_pipeline pipeline_cross_entropy_loss_back_f32;
468469 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -2914,6 +2915,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
29142915 ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
29152916 ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
29162917
2918+ ggml_vk_create_pipeline(device, device->pipeline_geglu_back_f32, "geglu_back_f32", geglu_back_f32_len, geglu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2919+
29172920 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
29182921
29192922 ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -6628,6 +6631,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
66286631 return ctx->device->pipeline_silu_back_f32;
66296632 }
66306633 return nullptr;
6634+ case GGML_OP_GEGLU_BACK:
6635+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6636+ return ctx->device->pipeline_geglu_back_f32;
6637+ }
6638+ return nullptr;
66316639 case GGML_OP_NORM:
66326640 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
66336641 return ctx->device->pipeline_norm_f32;
@@ -7761,6 +7769,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
77617769 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
77627770}
77637771
// Dispatch helper for GGML_OP_GEGLU_BACK.
//
// Forwards src0, src1 and dst to ggml_vk_op_f32 with no third source
// (nullptr), using the generic vk_op_push_constants layout. The call has the
// same shape as the GGML_OP_SILU_BACK helper: the first push-constant field
// is the element count of src0 and the remaining three fields are zero.
//
// Parameters:
//   ctx, subctx - backend context and command context, passed through as-is.
//   src0, src1  - the two input tensors; which one is the incoming gradient
//                 and which is the forward input is not visible here
//                 (TODO confirm against the geglu_back_f32 shader).
//   dst         - output tensor.
//   dryrun      - forwarded unchanged to ggml_vk_op_f32.
//
// NOTE(review): pipeline lookup for this op appears to require F32 for src0,
// src1 and dst and to yield nullptr otherwise - confirm how ggml_vk_op_f32
// handles a null pipeline.
7772+ static void ggml_vk_geglu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
// The element count is narrowed to uint32_t by the explicit cast below.
7773+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GEGLU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
7774+ }
7775+
77647776static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
77657777 float * op_params = (float *)dst->op_params;
77667778
@@ -7835,12 +7847,12 @@ static void ggml_vk_cross_entropy_loss_back(ggml_backend_vk_context * ctx, vk_co
78357847 const int64_t nclasses = src1->ne[0];
78367848 const int64_t nrows = ggml_nrows(src1);
78377849
7838- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_CROSS_ENTROPY_LOSS_BACK, {
7850+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_CROSS_ENTROPY_LOSS_BACK, {
78397851 (uint32_t)nclasses,
78407852 (uint32_t)nrows,
78417853 0.0f,
78427854 0.0f
7843- }, dryrun);
7855+ }, dryrun);
78447856}
78457857
78467858static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9112,6 +9124,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
91129124 case GGML_OP_CONT:
91139125 case GGML_OP_DUP:
91149126 case GGML_OP_SILU_BACK:
9127+ case GGML_OP_GEGLU_BACK:
91159128 case GGML_OP_NORM:
91169129 case GGML_OP_GROUP_NORM:
91179130 case GGML_OP_RMS_NORM:
@@ -9181,6 +9194,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
91819194 case GGML_OP_CONT:
91829195 case GGML_OP_DUP:
91839196 case GGML_OP_SILU_BACK:
9197+ case GGML_OP_GEGLU_BACK:
91849198 case GGML_OP_NORM:
91859199 case GGML_OP_GROUP_NORM:
91869200 case GGML_OP_RMS_NORM:
@@ -9303,6 +9317,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
93039317 case GGML_OP_SILU_BACK:
93049318 ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
93059319
9320+ break;
9321+ case GGML_OP_GEGLU_BACK:
9322+ ggml_vk_geglu_back(ctx, compute_ctx, src0, src1, node, dryrun);
9323+
93069324 break;
93079325 case GGML_OP_NORM:
93089326 ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -9362,7 +9380,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
93629380 ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
93639381
93649382 break;
9365-
9383+
93669384 case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
93679385 ggml_vk_cross_entropy_loss_back(ctx, compute_ctx, src0, src1, src2, node, dryrun);
93689386
@@ -9524,6 +9542,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
95249542 case GGML_OP_CONT:
95259543 case GGML_OP_DUP:
95269544 case GGML_OP_SILU_BACK:
9545+ case GGML_OP_GEGLU_BACK:
95279546 case GGML_OP_NORM:
95289547 case GGML_OP_GROUP_NORM:
95299548 case GGML_OP_RMS_NORM:
@@ -10693,6 +10712,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1069310712 (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
1069410713 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
1069510714 case GGML_OP_SILU_BACK:
10715+ case GGML_OP_GEGLU_BACK:
1069610716 case GGML_OP_RMS_NORM_BACK:
1069710717 case GGML_OP_SQR:
1069810718 case GGML_OP_SIN:
0 commit comments