@@ -4104,22 +4104,20 @@ static void llm_build_k_shift(
41044104       struct  ggml_cgraph  * graph,
41054105            llm_rope_type   type,
41064106                  int64_t    n_ctx,
4107-                   int        n_rot,
41084107                  float      freq_base,
41094108                  float      freq_scale,
41104109       const  llm_build_cb & cb) {
41114110    const  int64_t  n_layer       = hparams.n_layer ;
41124111    const  int64_t  n_head_kv     = hparams.n_head_kv ;
41134112    const  int64_t  n_embd_head_k = hparams.n_embd_head_k ;
41144113    const  int64_t  n_embd_k_gqa  = hparams.n_embd_k_gqa ();
4114+     const  int32_t  n_rot         = hparams.n_rot ;
41154115    const  int32_t  n_orig_ctx    = cparams.n_yarn_orig_ctx ;
41164116    const  float    ext_factor    = cparams.yarn_ext_factor ;
41174117    const  float    attn_factor   = cparams.yarn_attn_factor ;
41184118    const  float    beta_fast     = cparams.yarn_beta_fast ;
41194119    const  float    beta_slow     = cparams.yarn_beta_slow ;
41204120
4121-     GGML_ASSERT (n_embd_head_k % n_rot == 0 );
4122- 
41234121    struct  ggml_tensor  * K_shift = ggml_new_tensor_1d (ctx, GGML_TYPE_I32, n_ctx);
41244122    cb (K_shift, " K_shift" 1 );
41254123
@@ -4523,7 +4521,7 @@ struct llm_build_context {
45234521
45244522        //  shift the entire K-cache if needed
45254523        if  (do_rope_shift) {
4526-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head,  freq_base, freq_scale, cb);
4524+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
45274525        }
45284526
45294527        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -4561,14 +4559,14 @@ struct llm_build_context {
45614559
45624560                Qcur = ggml_rope_custom (
45634561                    ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
4564-                     n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4562+                     hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
45654563                    ext_factor, attn_factor, beta_fast, beta_slow
45664564                );
45674565                cb (Qcur, " Qcur" 
45684566
45694567                Kcur = ggml_rope_custom (
45704568                    ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4571-                     n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4569+                     hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
45724570                    ext_factor, attn_factor, beta_fast, beta_slow
45734571                );
45744572                cb (Kcur, " Kcur" 
@@ -4691,6 +4689,7 @@ struct llm_build_context {
46914689
46924690        const  int64_t  n_embd_head = hparams.n_embd_head_v ;
46934691        GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
4692+         GGML_ASSERT (n_embd_head == hparams.n_rot );
46944693
46954694        struct  ggml_tensor  * cur;
46964695        struct  ggml_tensor  * inpL;
@@ -4708,7 +4707,7 @@ struct llm_build_context {
47084707
47094708        //  shift the entire K-cache if needed
47104709        if  (do_rope_shift) {
4711-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head,  freq_base, freq_scale, cb);
4710+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
47124711        }
47134712
47144713        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -4734,12 +4733,12 @@ struct llm_build_context {
47344733                    case  MODEL_7B:
47354734                        Qcur = ggml_rope_custom (
47364735                            ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4737-                             n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4736+                             hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
47384737                            ext_factor, attn_factor, beta_fast, beta_slow
47394738                        );
47404739                        Kcur = ggml_rope_custom (
47414740                            ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4742-                             n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4741+                             hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
47434742                            ext_factor, attn_factor, beta_fast, beta_slow
47444743                        );
47454744                        break ;
@@ -4812,6 +4811,7 @@ struct llm_build_context {
48124811        const  int64_t  n_embd_head = hparams.n_embd_head_v ;
48134812        const  int64_t  n_embd_gqa  = hparams.n_embd_v_gqa ();
48144813        GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
4814+         GGML_ASSERT (n_embd_head == hparams.n_rot );
48154815
48164816        struct  ggml_tensor  * cur;
48174817        struct  ggml_tensor  * inpL;
@@ -4829,7 +4829,7 @@ struct llm_build_context {
48294829
48304830        //  shift the entire K-cache if needed
48314831        if  (do_rope_shift) {
4832-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head,  freq_base, freq_scale, cb);
4832+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
48334833        }
48344834
48354835        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -4870,13 +4870,13 @@ struct llm_build_context {
48704870
48714871                //  using mode = 2 for neox mode
48724872                Qcur = ggml_rope_custom (
4873-                     ctx0, Qcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
4873+                     ctx0, Qcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
48744874                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
48754875                );
48764876                cb (Qcur, " Qcur" 
48774877
48784878                Kcur = ggml_rope_custom (
4879-                     ctx0, Kcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
4879+                     ctx0, Kcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
48804880                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
48814881                );
48824882                cb (Kcur, " Kcur" 
@@ -5033,9 +5033,8 @@ struct llm_build_context {
50335033        struct  ggml_cgraph  * gf = ggml_new_graph_custom (ctx0, LLAMA_MAX_NODES, false );
50345034
50355035        const  int64_t  n_embd_head = hparams.n_embd_head_v ;
5036-         GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5037- 
5038-         const  int64_t  n_rot = n_embd_head_k / 2 ;
5036+         GGML_ASSERT (n_embd_head   == hparams.n_embd_head_k );
5037+         GGML_ASSERT (n_embd_head/2  == hparams.n_rot );
50395038
50405039        struct  ggml_tensor  * cur;
50415040        struct  ggml_tensor  * inpL;
@@ -5052,7 +5051,7 @@ struct llm_build_context {
50525051        cb (KQ_mask, " KQ_mask" 1 );
50535052
50545053        if  (do_rope_shift) {
5055-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head,  freq_base, freq_scale, cb);
5054+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
50565055        }
50575056
50585057        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -5112,15 +5111,15 @@ struct llm_build_context {
51125111
51135112                //  RoPE the first n_rot of q/k, pass the other half, and concat.
51145113                struct  ggml_tensor  * qrot = ggml_view_3d (
5115-                         ctx0, tmpq, n_rot, n_head, n_tokens,
5114+                         ctx0, tmpq, hparams. n_rot , n_head, n_tokens,
51165115                        ggml_element_size (tmpq) * n_embd_head,
51175116                        ggml_element_size (tmpq) * n_embd_head * n_head,
51185117                        0 
51195118                        );
51205119                cb (qrot, " qrot" 
51215120
51225121                struct  ggml_tensor  * krot = ggml_view_3d (
5123-                         ctx0, tmpk, n_rot, n_head, n_tokens,
5122+                         ctx0, tmpk, hparams. n_rot , n_head, n_tokens,
51245123                        ggml_element_size (tmpk) * n_embd_head,
51255124                        ggml_element_size (tmpk) * n_embd_head * n_head,
51265125                        0 
@@ -5129,29 +5128,29 @@ struct llm_build_context {
51295128
51305129                //  get the second half of tmpq, e.g tmpq[n_rot:, :, :]
51315130                struct  ggml_tensor  * qpass = ggml_view_3d (
5132-                         ctx0, tmpq, n_rot, n_head, n_tokens,
5131+                         ctx0, tmpq, hparams. n_rot , n_head, n_tokens,
51335132                        ggml_element_size (tmpq) * n_embd_head,
51345133                        ggml_element_size (tmpq) * n_embd_head * n_head,
5135-                         ggml_element_size (tmpq) * n_rot
5134+                         ggml_element_size (tmpq) * hparams. n_rot 
51365135                        );
51375136                cb (qpass, " qpass" 
51385137
51395138                struct  ggml_tensor  * kpass = ggml_view_3d (
5140-                         ctx0, tmpk, n_rot, n_head, n_tokens,
5139+                         ctx0, tmpk, hparams. n_rot , n_head, n_tokens,
51415140                        ggml_element_size (tmpk) * n_embd_head,
51425141                        ggml_element_size (tmpk) * n_embd_head * n_head,
5143-                         ggml_element_size (tmpk) * n_rot
5142+                         ggml_element_size (tmpk) * hparams. n_rot 
51445143                        );
51455144                cb (kpass, " kpass" 
51465145
51475146                struct  ggml_tensor  * qrotated = ggml_rope_custom (
5148-                     ctx0, qrot, inp_pos, n_rot, 2 , 0 , n_orig_ctx,
5147+                     ctx0, qrot, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
51495148                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
51505149                );
51515150                cb (qrotated, " qrotated" 
51525151
51535152                struct  ggml_tensor  * krotated = ggml_rope_custom (
5154-                     ctx0, krot, inp_pos, n_rot, 2 , 0 , n_orig_ctx,
5153+                     ctx0, krot, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
51555154                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
51565155                );
51575156                cb (krotated, " krotated" 
@@ -5531,6 +5530,7 @@ struct llm_build_context {
55315530
55325531        const  int64_t  n_embd_head = hparams.n_embd_head_v ;
55335532        GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5533+         GGML_ASSERT (n_embd_head == hparams.n_rot );
55345534
55355535        struct  ggml_tensor  * cur;
55365536        struct  ggml_tensor  * inpL;
@@ -5548,7 +5548,7 @@ struct llm_build_context {
55485548
55495549        //  shift the entire K-cache if needed
55505550        if  (do_rope_shift) {
5551-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams. n_rot ,  freq_base, freq_scale, cb);
5551+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
55525552        }
55535553
55545554        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -5661,7 +5661,7 @@ struct llm_build_context {
56615661
56625662        //  shift the entire K-cache if needed
56635663        if  (do_rope_shift) {
5664-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head,  freq_base, freq_scale, cb);
5664+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
56655665        }
56665666
56675667        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -5693,13 +5693,13 @@ struct llm_build_context {
56935693
56945694                //  using mode = 2 for neox mode
56955695                Qcur = ggml_rope_custom (
5696-                     ctx0, Qcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
5696+                     ctx0, Qcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
56975697                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
56985698                );
56995699                cb (Qcur, " Qcur" 
57005700
57015701                Kcur = ggml_rope_custom (
5702-                     ctx0, Kcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
5702+                     ctx0, Kcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
57035703                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
57045704                );
57055705                cb (Kcur, " Kcur" 
@@ -5778,7 +5778,7 @@ struct llm_build_context {
57785778
57795779        //  shift the entire K-cache if needed
57805780        if  (do_rope_shift) {
5781-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head,  freq_base, freq_scale, cb);
5781+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
57825782        }
57835783
57845784        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -5874,6 +5874,7 @@ struct llm_build_context {
58745874
58755875        const  int64_t  n_embd_head = hparams.n_embd_head_v ;
58765876        GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5877+         GGML_ASSERT (n_embd_head == hparams.n_rot );
58775878
58785879        struct  ggml_tensor  * cur;
58795880        struct  ggml_tensor  * inpL;
@@ -5891,7 +5892,7 @@ struct llm_build_context {
58915892
58925893        //  shift the entire K-cache if needed
58935894        if  (do_rope_shift) {
5894-             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head,  freq_base, freq_scale, cb);
5895+             llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
58955896        }
58965897
58975898        for  (int  il = 0 ; il < n_layer; ++il) {
@@ -5917,13 +5918,13 @@ struct llm_build_context {
59175918                cb (Vcur, " Vcur" 
59185919
59195920                Qcur = ggml_rope_custom (
5920-                         ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head , n_head,    n_tokens), inp_pos,
5921+                         ctx0, ggml_reshape_3d (ctx0, Qcur, hparams. n_rot , n_head,    n_tokens), inp_pos,
59215922                        n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
59225923                        ext_factor, attn_factor, beta_fast, beta_slow);
59235924                cb (Qcur, " Qcur" 
59245925
59255926                Kcur = ggml_rope_custom (
5926-                         ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head , n_head_kv, n_tokens), inp_pos,
5927+                         ctx0, ggml_reshape_3d (ctx0, Kcur, hparams. n_rot , n_head_kv, n_tokens), inp_pos,
59275928                        n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
59285929                        ext_factor, attn_factor, beta_fast, beta_slow);
59295930                cb (Kcur, " Kcur" 
0 commit comments