4141 GGML_METAL_KERNEL_TYPE_TANH,
4242 GGML_METAL_KERNEL_TYPE_RELU,
4343 GGML_METAL_KERNEL_TYPE_GELU,
44+ GGML_METAL_KERNEL_TYPE_GELU_4,
4445 GGML_METAL_KERNEL_TYPE_GELU_QUICK,
46+ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
4547 GGML_METAL_KERNEL_TYPE_SILU,
48+ GGML_METAL_KERNEL_TYPE_SILU_4,
4649 GGML_METAL_KERNEL_TYPE_SOFT_MAX,
4750 GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
4851 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -473,8 +476,11 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
473476 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TANH, tanh, true );
474477 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RELU, relu, true );
475478 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU, gelu, true );
479+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true );
476480 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true );
481+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true );
477482 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU, silu, true );
483+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true );
478484 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction );
479485 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction );
480486 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true );
@@ -1178,6 +1184,9 @@ static enum ggml_status ggml_metal_graph_compute(
11781184 } break ;
11791185 case GGML_OP_UNARY:
11801186 switch (ggml_get_unary_op (gf->nodes [i])) {
1187+ // we are not taking into account the strides, so for now require contiguous tensors
1188+ GGML_ASSERT (ggml_is_contiguous (src0));
1189+
11811190 case GGML_UNARY_OP_TANH:
11821191 {
11831192 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TANH].pipeline ;
@@ -1204,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
12041213 } break ;
12051214 case GGML_UNARY_OP_GELU:
12061215 {
1207- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU].pipeline ;
1216+ int64_t n = ggml_nelements (dst);
1217+
1218+ id <MTLComputePipelineState > pipeline = nil ;
1219+
1220+ if (n % 4 == 0 ) {
1221+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_4].pipeline ;
1222+ n /= 4 ;
1223+ } else {
1224+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU].pipeline ;
1225+ }
12081226
12091227 [encoder setComputePipelineState: pipeline];
12101228 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
12111229 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
12121230
1213- const int64_t n = ggml_nelements (dst);
1214- GGML_ASSERT (n % 4 == 0 );
1215-
1216- [encoder dispatchThreadgroups: MTLSizeMake (n/4 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1231+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
12171232 } break ;
12181233 case GGML_UNARY_OP_GELU_QUICK:
12191234 {
1220- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline ;
1235+ int64_t n = ggml_nelements (dst);
1236+
1237+ id <MTLComputePipelineState > pipeline = nil ;
1238+
1239+ if (n % 4 == 0 ) {
1240+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline ;
1241+ n /= 4 ;
1242+ } else {
1243+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline ;
1244+ }
12211245
12221246 [encoder setComputePipelineState: pipeline];
12231247 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
12241248 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
12251249
1226- const int64_t n = ggml_nelements (dst);
1227- GGML_ASSERT (n % 4 == 0 );
1228-
1229- [encoder dispatchThreadgroups: MTLSizeMake (n/4 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1250+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
12301251 } break ;
12311252 case GGML_UNARY_OP_SILU:
12321253 {
1233- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU].pipeline ;
1254+ int64_t n = ggml_nelements (dst);
1255+
1256+ id <MTLComputePipelineState > pipeline = nil ;
1257+
1258+ if (n % 4 == 0 ) {
1259+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU_4].pipeline ;
1260+ n /= 4 ;
1261+ } else {
1262+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU].pipeline ;
1263+ }
12341264
12351265 [encoder setComputePipelineState: pipeline];
12361266 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
12371267 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
12381268
1239- const int64_t n = ggml_nelements (dst);
1240- GGML_ASSERT (n % 4 == 0 );
1241-
1242- [encoder dispatchThreadgroups: MTLSizeMake (n/4 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1269+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
12431270 } break ;
12441271 default :
12451272 {
0 commit comments