@@ -519,48 +519,56 @@ void ggml_metal_graph_compute(
519519
520520 [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
521521 } break ;
522- case GGML_OP_SILU:
523- {
524- if (encoder == nil ) {
525- encoder = [command_buffer computeCommandEncoder ];
526- }
527-
528- [encoder setComputePipelineState: ctx->pipeline_silu];
529- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
530- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
531-
532- const int64_t n = ggml_nelements (dst);
533-
534- [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
535- } break ;
536- case GGML_OP_RELU:
537- {
538- if (encoder == nil ) {
539- encoder = [command_buffer computeCommandEncoder ];
540- }
541-
542- [encoder setComputePipelineState: ctx->pipeline_relu];
543- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
544- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
545-
546- const int64_t n = ggml_nelements (dst);
547-
548- [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
522+ case GGML_OP_UNARY:
523+ switch (ggml_get_unary_op (gf->nodes [i])) {
524+ case GGML_UNARY_OP_SILU:
525+ {
526+ if (encoder == nil ) {
527+ encoder = [command_buffer computeCommandEncoder ];
528+ }
529+
530+ [encoder setComputePipelineState: ctx->pipeline_silu];
531+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
532+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
533+
534+ const int64_t n = ggml_nelements (dst);
535+
536+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
537+ } break ;
538+ case GGML_UNARY_OP_RELU:
539+ {
540+ if (encoder == nil ) {
541+ encoder = [command_buffer computeCommandEncoder ];
542+ }
543+
544+ [encoder setComputePipelineState: ctx->pipeline_relu];
545+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
546+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
547+
548+ const int64_t n = ggml_nelements (dst);
549+
550+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
551+ } break ;
552+ case GGML_UNARY_OP_GELU:
553+ {
554+ if (encoder == nil ) {
555+ encoder = [command_buffer computeCommandEncoder ];
556+ }
557+
558+ [encoder setComputePipelineState: ctx->pipeline_gelu];
559+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
560+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
561+
562+ const int64_t n = ggml_nelements (dst);
563+
564+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
565+ } break ;
566+ default :
567+ {
568+ fprintf (stderr, " %s : node %3d , op = %8s not implemented\n " , __func__, i, ggml_op_name (dst->op ));
569+ GGML_ASSERT (false );
570+ }
549571 } break ;
550- case GGML_OP_GELU:
551- {
552- if (encoder == nil ) {
553- encoder = [command_buffer computeCommandEncoder ];
554- }
555-
556- [encoder setComputePipelineState: ctx->pipeline_gelu];
557- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
558- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
559-
560- const int64_t n = ggml_nelements (dst);
561-
562- [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
563- } break ;
564572 case GGML_OP_SOFT_MAX:
565573 {
566574 if (encoder == nil ) {
@@ -979,8 +987,10 @@ void ggml_metal_graph_compute(
979987 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
980988 } break ;
981989 default :
982- fprintf (stderr, " %s : node %3d , op = %8s not implemented\n " , __func__, i, ggml_op_name (dst->op ));
983- GGML_ASSERT (false );
990+ {
991+ fprintf (stderr, " %s : node %3d , op = %8s not implemented\n " , __func__, i, ggml_op_name (dst->op ));
992+ GGML_ASSERT (false );
993+ }
984994 }
985995 }
986996
0 commit comments