@@ -4609,8 +4609,8 @@ static __global__ void rope(
46094609
46104610template <typename  T, bool  has_pos>
46114611static  __global__  void  rope_neox (
4612-     const  T * x, T * dst, int  ncols, const  int32_t  * pos, float  freq_scale, int  p_delta_rows, float  freq_base,
4613-     float  ext_factor, float  attn_factor, rope_corr_dims corr_dims
4612+     const  T * x, T * dst, int  ncols, int  n_dims,  const  int32_t  * pos, float  freq_scale, int  p_delta_rows, float  freq_base,
4613+     float  ext_factor, float  attn_factor, rope_corr_dims corr_dims,  float  theta_scale,  float  inv_ndims 
46144614) {
46154615    const  int  col = 2 *(blockDim .y *blockIdx .y  + threadIdx .y );
46164616
@@ -4619,23 +4619,25 @@ static __global__ void rope_neox(
46194619    }
46204620
46214621    const  int  row = blockDim .x *blockIdx .x  + threadIdx .x ;
4622-     const  int  i = row*ncols + col/2 ;
4622+     const  int  ib = col / n_dims;
4623+     const  int  ic = col % n_dims;
4624+ 
4625+     const  int  i = row*ncols + ib*n_dims + ic/2 ;
46234626    const  int  i2 = row/p_delta_rows;
46244627
4625-     //  simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
4626-     const  float  cur_rot = -float (col)/ncols;
4628+     float  cur_rot = inv_ndims * ic - ib;
46274629
46284630    const  int  p = has_pos ? pos[i2] : 0 ;
4629-     const  float  theta_base = p*powf (freq_base, cur_rot );
4631+     const  float  theta_base = p*freq_scale* powf (theta_scale, col/ 2 . 0f );
46304632
46314633    float  cos_theta, sin_theta;
46324634    rope_yarn (theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
46334635
46344636    const  float  x0 = x[i + 0 ];
4635-     const  float  x1 = x[i + ncols /2 ];
4637+     const  float  x1 = x[i + n_dims /2 ];
46364638
4637-     dst[i + 0 ]       = x0*cos_theta - x1*sin_theta;
4638-     dst[i + ncols /2 ] = x0*sin_theta + x1*cos_theta;
4639+     dst[i + 0 ]         = x0*cos_theta - x1*sin_theta;
4640+     dst[i + n_dims /2 ] = x0*sin_theta + x1*cos_theta;
46394641}
46404642
46414643static  __global__  void  rope_glm_f32 (
@@ -5738,20 +5740,26 @@ static void rope_cuda(
57385740
57395741template <typename  T>
57405742static  void  rope_neox_cuda (
5741-     const  T * x, T * dst, int  ncols, int  nrows, const  int32_t  * pos, float  freq_scale, int  p_delta_rows,
5743+     const  T * x, T * dst, int  ncols, int  n_dims,  int   nrows, const  int32_t  * pos, float  freq_scale, int  p_delta_rows,
57425744    float  freq_base, float  ext_factor, float  attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
57435745) {
57445746    GGML_ASSERT (ncols % 2  == 0 );
57455747    const  dim3  block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
57465748    const  int  num_blocks_x = (ncols + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
57475749    const  dim3  block_nums (nrows, num_blocks_x, 1 );
5750+ 
5751+     const  float  theta_scale = powf (freq_base, -2 .0f /n_dims);
5752+     const  float  inv_ndims = -1 .0f  / n_dims;
5753+ 
57485754    if  (pos == nullptr ) {
57495755        rope_neox<T, false ><<<block_nums, block_dims, 0 , stream>>> (
5750-             x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
5756+             x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
5757+             theta_scale, inv_ndims
57515758        );
57525759    } else  {
57535760        rope_neox<T, true ><<<block_nums, block_dims, 0 , stream>>> (
5754-             x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
5761+             x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
5762+             theta_scale, inv_ndims
57555763        );
57565764    }
57575765}
@@ -6706,15 +6714,14 @@ inline void ggml_cuda_op_rope(
67066714        GGML_ASSERT (false );
67076715        rope_glm_f32_cuda (src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
67086716    } else  if  (is_neox) {
6709-         GGML_ASSERT (ne00 == n_dims && " ne00 != n_dims is not implemented for CUDA yet"  );
67106717        if  (src0->type  == GGML_TYPE_F32) {
67116718            rope_neox_cuda (
6712-                 (const  float  *)src0_dd, (float  *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6719+                 (const  float  *)src0_dd, (float  *)dst_dd, ne00, n_dims,  nrows, pos, freq_scale, ne01, freq_base, ext_factor,
67136720                attn_factor, corr_dims, main_stream
67146721            );
67156722        } else  if  (src0->type  == GGML_TYPE_F16) {
67166723            rope_neox_cuda (
6717-                 (const  half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6724+                 (const  half *)src0_dd, (half *)dst_dd, ne00, n_dims,  nrows, pos, freq_scale, ne01, freq_base, ext_factor,
67186725                attn_factor, corr_dims, main_stream
67196726            );
67206727        } else  {
0 commit comments