@@ -377,29 +377,27 @@ int llama_mtl_eval(
377377 id <MTLBuffer > id_src1 = llama_mtl_get_buffer (ctx, gf->nodes [i]->src1 , &offs_src1);
378378 id <MTLBuffer > id_dst = llama_mtl_get_buffer (ctx, gf->nodes [i], &offs_dst);
379379
380- const int64_t ncols0 = gf->nodes [i]->src0 ->ne [0 ];
381- const int64_t nrows0 = gf->nodes [i]->src0 ->ne [1 ];
382-
383- const int64_t ncols1 = gf->nodes [i]->src1 ->ne [0 ];
384- const int64_t nrows1 = gf->nodes [i]->src1 ->ne [1 ];
385-
386- const int64_t ncols = gf->nodes [i]->ne [0 ];
387- const int64_t nrows = gf->nodes [i]->ne [1 ];
380+ const int64_t ne00 = gf->nodes [i]->src0 ->ne [0 ];
381+ const int64_t ne01 = gf->nodes [i]->src0 ->ne [1 ];
382+ const int64_t ne10 = gf->nodes [i]->src1 ->ne [0 ];
383+ const int64_t ne11 = gf->nodes [i]->src1 ->ne [1 ];
384+ const int64_t ne0 = gf->nodes [i]->ne [0 ];
385+ const int64_t ne1 = gf->nodes [i]->ne [1 ];
388386
389387 [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_0];
390388 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
391389 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
392390 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
393- [encoder setBytes: &ncols0 length: sizeof (ncols0 ) atIndex: 3 ];
394- [encoder setBytes: &nrows0 length: sizeof (nrows0 ) atIndex: 4 ];
395- [encoder setBytes: &ncols1 length: sizeof (ncols1 ) atIndex: 5 ];
396- [encoder setBytes: &nrows1 length: sizeof (nrows1 ) atIndex: 6 ];
397- [encoder setBytes: &ncols length: sizeof (ncols ) atIndex: 7 ];
398- [encoder setBytes: &nrows length: sizeof (nrows ) atIndex: 8 ];
391+ [encoder setBytes: &ne00 length: sizeof (ne00 ) atIndex: 3 ];
392+ [encoder setBytes: &ne01 length: sizeof (ne01 ) atIndex: 4 ];
393+ [encoder setBytes: &ne10 length: sizeof (ne10 ) atIndex: 5 ];
394+ [encoder setBytes: &ne11 length: sizeof (ne11 ) atIndex: 6 ];
395+ [encoder setBytes: &ne0 length: sizeof (ne0 ) atIndex: 7 ];
396+ [encoder setBytes: &ne1 length: sizeof (ne1 ) atIndex: 8 ];
399397
400- printf (" mul_mat: %lld x%lld * %lld x%lld -> %lld x%lld \n " , ncols0, nrows0, ncols1, nrows1, ncols, nrows );
398+ printf (" mul_mat: %lld x%lld * %lld x%lld -> %lld x%lld \n " , ne00, ne01, ne10, ne11, ne0, ne1 );
401399
402- [encoder dispatchThreadgroups: MTLSizeMake (nrows0, nrows1 , 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
400+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11 , 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
403401 } break ;
404402 case GGML_OP_GET_ROWS:
405403 {
0 commit comments