@@ -483,34 +483,26 @@ void main() {
483483 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
484484 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
485485
486- const uint ib = idx / 128; // 2 values per idx
487- const uint ib32 = (idx % 128) / 16; // 0..7
488- const uint ib8 = (idx % 128) / 4;
489- const int i8 = 2 * int(idx % 4);
486+ const uint ib = idx / 32; // 8 values per idx
487+ const uint ib32 = (idx % 32) / 4; // 0..7
488+ const uint ib8 = idx % 32;
490489
491490 const float d = float(data_a[ib].d);
492491 const uint qh = data_a[ib].qh[ib32];
493492 const uint qs = data_a[ib].qs[ib8];
494493 const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
495494 const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
496495 const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
497-
498- const ivec2 gvec = ivec2(
499- bitfieldExtract(grid, 2 * (i8), 2),
500- bitfieldExtract(grid, 2 * (i8 + 1), 2)
501- );
502- const vec2 v = dl * (vec2(gvec) + delta);
503-
504- buf_a[buf_idx ] = BUF_TYPE(v.x);
505- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
496+ [[unroll]] for (int k = 0; k < 8; ++k) {
497+ buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
498+ }
506499#elif defined(DATA_A_IQ1_M)
507500 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
508501 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
509502
510- const uint ib = idx / 128 ; // 2 values per idx
511- const uint ib8 = ( idx % 128) / 4 ;
503+ const uint ib = idx / 32 ; // 8 values per idx
504+ const uint ib8 = idx % 32 ;
512505 const uint ib16 = ib8 / 2;
513- const int i8 = 2 * int(idx % 4);
514506
515507 const uint16_t[4] scales = data_a[ib].scales;
516508 const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -521,21 +513,16 @@ void main() {
521513 const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
522514 const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
523515 const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
524- const ivec2 gvec = ivec2(
525- bitfieldExtract(grid, 2 * (i8), 2),
526- bitfieldExtract(grid, 2 * (i8 + 1), 2)
527- );
528- const vec2 v = dl * (vec2(gvec) + delta);
529-
530- buf_a[buf_idx ] = BUF_TYPE(v.x);
531- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
516+ [[unroll]] for (int k = 0; k < 8; ++k) {
517+ buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
518+ }
532519#elif defined(DATA_A_IQ2_XXS)
533520 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
534521 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
535522
536- const uint ib = idx / 128 ; // 2 values per idx
537- const uint ib32 = (idx % 128 ) / 16 ; // 0..7
538- const uint ib8 = ( idx / 4) % 4;
523+ const uint ib = idx / 32 ; // 8 values per idx
524+ const uint ib32 = (idx % 32 ) / 4 ; // 0..7
525+ const uint ib8 = idx % 4;
539526
540527 const float d = float(data_a[ib].d);
541528 const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -545,63 +532,81 @@ void main() {
545532 data_a[ib].qs[8*ib32 + 6],
546533 data_a[ib].qs[8*ib32 + 7]
547534 ));
548- const float db = d * 0.25 * (0.5 + (signs >> 28));
535+ const BUF_TYPE db = BUF_TYPE( d * 0.25 * (0.5 + (signs >> 28) ));
549536 const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
550- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
551- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
552- const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
553- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
554-
555- buf_a[buf_idx ] = BUF_TYPE(v.x);
556- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
537+ const uint sign = sign7 | (bitCount(sign7) << 7);
538+ const uvec2 grid = iq2xxs_grid[qs];
539+ const vec4 grid0 = vec4(unpack8(grid.x));
540+ const vec4 grid1 = vec4(unpack8(grid.y));
541+
542+ buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
543+ buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
544+ buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
545+ buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
546+ buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
547+ buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
548+ buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
549+ buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
557550#elif defined(DATA_A_IQ2_XS)
558551 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
559552 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
560553
561- const uint ib = idx / 128 ; // 2 values per idx
562- const uint ib32 = (idx % 128 ) / 16; // 0..7
563- const uint ib8 = ( idx / 4) % 4; // 0..3
554+ const uint ib = idx / 32 ; // 8 values per idx
555+ const uint ib32 = (idx % 32 ) / 4; // 0..7
556+ const uint ib8 = idx % 4; // 0..3
564557
565558 const float d = float(data_a[ib].d);
566559 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
567- const float db = d * 0.25 * (0.5 + scale);
560+ const BUF_TYPE db = BUF_TYPE( d * 0.25 * (0.5 + scale) );
568561 const uint qs = data_a[ib].qs[4 * ib32 + ib8];
569562 const uint sign7 = qs >> 9;
570- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
571- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
572- const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
573- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
574-
575- buf_a[buf_idx ] = BUF_TYPE(v.x);
576- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
563+ const uint sign = sign7 | (bitCount(sign7) << 7);
564+ const uvec2 grid = iq2xs_grid[qs & 511];
565+ const vec4 grid0 = vec4(unpack8(grid.x));
566+ const vec4 grid1 = vec4(unpack8(grid.y));
567+
568+ buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
569+ buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
570+ buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
571+ buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
572+ buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
573+ buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
574+ buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
575+ buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
577576#elif defined(DATA_A_IQ2_S)
578577 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
579578 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
580579
581- const uint ib = idx / 128 ; // 2 values per idx
582- const uint ib8 = ( idx % 128) / 4 ; // 0..31
583- const uint ib32 = ib8 / 4; // 0..7
580+ const uint ib = idx / 32 ; // 8 values per idx
581+ const uint ib8 = idx % 32 ; // 0..31
582+ const uint ib32 = ib8 / 4; // 0..7
584583
585584 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
586585 const uint qs = data_a[ib].qs[ib8];
587586 const uint qh = data_a[ib].qh[ib32];
588587 const uint qhshift = 2 * (ib8 % 4);
589- const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)) ;
588+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
590589
591590 const float d = float(data_a[ib].d);
592- const float db = d * 0.25 * (0.5 + scale);
593- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
594- const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
595- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
596-
597- buf_a[buf_idx ] = BUF_TYPE(v.x);
598- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
591+ const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + scale));
592+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
593+ const vec4 grid0 = vec4(unpack8(grid.x));
594+ const vec4 grid1 = vec4(unpack8(grid.y));
595+
596+ buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
597+ buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
598+ buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
599+ buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
600+ buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
601+ buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
602+ buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
603+ buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
599604#elif defined(DATA_A_IQ3_XXS)
600605 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
601606 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
602607
603- const uint ib = idx / 128 ; // 2 values per idx
604- const uint iqs = ( idx % 128) / 2 ; // 0..63
608+ const uint ib = idx / 64 ; // 4 values per idx
609+ const uint iqs = idx % 64 ; // 0..63
605610 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
606611
607612 const float d = float(data_a[ib].d);
@@ -614,33 +619,35 @@ void main() {
614619 ));
615620 const float db = d * 0.5 * (0.5 + (signs >> 28));
616621 const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
617- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
618- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
619- const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
620- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
621-
622- buf_a[buf_idx ] = BUF_TYPE(v.x);
623- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
622+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
623+ const uint grid = iq3xxs_grid[qs];
624+ const vec4 v = db * vec4(unpack8(grid));
625+
626+ buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x);
627+ buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y);
628+ buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z);
629+ buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w);
624630#elif defined(DATA_A_IQ3_S)
625631 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
626632 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
627633
628- const uint ib = idx / 128 ; // 2 values per idx
629- const uint iqs = ( idx % 128) / 2 ; // 0..63
634+ const uint ib = idx / 64 ; // 4 values per idx
635+ const uint iqs = idx % 64 ; // 0..63
630636 const uint iqh = iqs / 8;
631637
632638 const float d = float(data_a[ib].d);
633639 const uint qs = data_a[ib].qs[iqs];
634640 const uint qh = data_a[ib].qh[iqh];
635- const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4 )));
641+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2 )));
636642 const uint scale = data_a[ib].scales[iqs / 16];
637- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
638643 const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
639- const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)) ;
640- const vec2 v = db * vec2(sign01) * vec2( unpack8(grid).xy );
644+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
645+ const vec4 v = db * vec4( unpack8(grid));
641646
642- buf_a[buf_idx ] = BUF_TYPE(v.x);
643- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
647+ buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x);
648+ buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y);
649+ buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z);
650+ buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w);
644651#elif defined(DATA_A_IQ4_XS)
645652 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
646653 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
0 commit comments