@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };

+static const char * cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 static void init_kleidiai_context(void) {

     ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+        if (ctx.kernels) {
+            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        }
+#endif
     }
     ggml_critical_section_end();
 }
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1

 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
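+        // Only mul_mat needs a KleidiAI work buffer; skip the sizing for any other op.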
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
+        } else if (dst->op == GGML_OP_GET_ROWS) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_get_rows(params, dst);
+            }
         }
         return false;
     }
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }

     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];

@@ -342,26 +367,62 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }

+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+        GGML_ASSERT(ctx.kernels);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
+        kernel_info * kernel = &ctx.kernels->gemm;
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        const size_t block_rows = kernel->get_nr();
+        const size_t kr = kernel->get_kr();
+
+        const size_t num_bytes_multiplier = sizeof(uint16_t);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
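+        // Split the requested rows evenly across the participating threads.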
+        const int dr = (nr + nth - 1) / nth;
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            GGML_ASSERT(src1->type == GGML_TYPE_I32);
+            int64_t row_idx = ((const int32_t *)src1->data)[i];
+            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
+
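+            // Dequantize the selected packed Q4_0 row straight into the corresponding output row.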
+            float *out = (float *)((char *)dst->data + i * nb1);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+        }
+
+        return true;
+    }
+
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
         GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
         size_t nr = ctx.kernels->gemm.get_nr();
         size_t kr = ctx.kernels->gemm.get_kr();
         size_t sr = ctx.kernels->gemm.get_sr();

-#ifndef NDEBUG
-        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
-        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
-#endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
         variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, nullptr, tensor->data, 0, &params);

         return 0;
-
         GGML_UNUSED(data_size);
     }
 };
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

-    GGML_UNUSED(buffer);
     return GGML_STATUS_SUCCESS;
+    GGML_UNUSED(buffer);
 }

 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }

+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(ctx.kernels);
+
+    const size_t n = tensor->ne[1];
+    const size_t k = tensor->ne[0];
+    const size_t nr = ctx.kernels->gemm.get_nr();
+    const size_t kr = ctx.kernels->gemm.get_kr();
+
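+    // Allocate based on the KleidiAI packed layout size rather than ggml_nbytes(), since repack() stores tensors in that format.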
+    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT &&
+        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
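+            // Only take GET_ROWS when the index tensor has ne[0] == 8; other shapes fall back to the generic CPU path.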
+            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_F32 &&
+            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }

     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
             /* .alloc_buffer    = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
             /* .get_alignment   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
             /* .get_max_size    = */ nullptr, // defaults to SIZE_MAX
-            /* .get_alloc_size  = */ nullptr, // defaults to ggml_nbytes
+            /* .get_alloc_size  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
             /* .is_host         = */ nullptr,
         },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),