@@ -6696,8 +6696,10 @@ inline void ggml_cuda_op_clamp(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const float min = ((float *) dst->op_params)[0];
-    const float max = ((float *) dst->op_params)[1];
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));

     clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
     CUDA_CHECK(cudaGetLastError());
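Note on the clamp hunk: ggml stores op_params as an int32_t array, so reading the bounds through a float cast aliases int32_t storage as float, which is undefined behavior under C/C++ strict-aliasing rules. memcpy expresses the same bit-level read in a well-defined way, and optimizing compilers lower it to a plain load. A minimal standalone sketch of the pattern (get_op_param_f32 is a hypothetical helper, not part of the commit):

    // Hypothetical helper illustrating the memcpy-based type pun; assumes
    // op_params is declared as an int32_t array, as in ggml.
    #include <cstdint>
    #include <cstring>

    static inline float get_op_param_f32(const int32_t * op_params, int i) {
        float v;
        std::memcpy(&v, op_params + i, sizeof(float)); // well-defined, compiles to a load
        return v;
    }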
@@ -7221,6 +7223,30 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

+__global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        void ** ptrs,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3) {
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+}
+
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
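The new kernel builds, on the device, the pointer table that cublasGemmBatchedEx consumes: a flat array of 3*ne23 entries, where group 0 holds the src0 (A) pointers, group 1 the src1 (B) pointers, and group 2 the dst (C) pointers, each group indexed by the linearized batch coordinate i12 + i13*ne12. The divisions i13/r3 and i12/r2 broadcast src0 across the larger src1 batch (r2 and r3 are the broadcast ratios computed earlier in this function), and the /2 on the src1 and dst strides converts byte strides taken from the F32 tensors into offsets into the F16 staging buffers; src0 is already F16, so nb02/nb03 are used unchanged. One caveat of the launch in the next hunk: it uses a single block of ne13*ne12 threads, which is bounded by the 1024-threads-per-block limit. Since the kernel already folds blockIdx into i12/i13, a gridded launch would lift that cap with no kernel changes. A hypothetical variant:

    // Hypothetical tiled launch (not in this commit), assuming the same
    // variables as the call site below; splits the batch across a grid so
    // ne12*ne13 is no longer capped by the threads-per-block limit.
    const dim3 block_dims(16, 16);
    const dim3 grid_dims((ne13 + block_dims.x - 1) / block_dims.x,
                         (ne12 + block_dims.y - 1) / block_dims.y);
    k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
            src0_as_f16, src1_as_f16, dst_f16, ptrs_as,
            ne12, ne13, ne23,
            nb02, nb03, nb12, nb13,
            dst->nb[2], dst->nb[3],
            r2, r3);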
@@ -7322,49 +7348,35 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
         const int ne23 = ne12*ne13;

-        // TODO: avoid this alloc
-        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
-                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
-            }
-        }
-
-        // allocate device memory for pointers
         void ** ptrs_as = nullptr;
-        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
-
-        // TODO: this does not work for some reason -- not sure why?
-        //size_t ptrs_s = 0;
-        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
-
-        // copy pointers to device
-        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
-
-        free(ptrs);
+        size_t ptrs_s = 0;
+        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_as_f16, src1_as_f16, dst_f16,
+                ptrs_as,
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                nb12, nb13,
+                dst->nb[2], dst->nb[3],
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());

         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **        ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

-        // free device memory for pointers
-        CUDA_CHECK(cudaFree(ptrs_as));
-        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        ggml_cuda_pool_free(ptrs_as, ptrs_s);
     }
 #endif

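Two smaller points in this hunk. First, the pointer table now comes from ggml-cuda's buffer pool instead of a per-call cudaMalloc/cudaMemcpy/cudaFree, and it is filled by the kernel above in stream order, so the host-side staging array and the synchronous host-to-device copy disappear from the hot path. Second, the casts now mirror the cuBLAS prototype, which in current headers declares the A and B tables as const void * const Aarray[]; unlike (const void **), adding const at every pointer level is a qualification conversion C++ permits. A standalone sketch of the typing (fake_gemm_batched is a made-up stand-in, not the cuBLAS symbol):

    // Made-up prototype mirroring how cublasGemmBatchedEx types its pointer
    // tables; shows why the casts are spelled with const at every level.
    extern "C" int fake_gemm_batched(const void * const Aarray[],
                                     const void * const Barray[],
                                     void * const       Carray[],
                                     int batch_count);

    int call_it(void ** ptrs_as, int ne23) {
        return fake_gemm_batched((const void * const *) (ptrs_as + 0*ne23),
                                 (const void * const *) (ptrs_as + 1*ne23),
                                 (      void * const *) (ptrs_as + 2*ne23),
                                 ne23);
    }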