@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
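The three warp-level helpers added above are hoisted ahead of the kernels so that the new soft_max_f32 can reuse them. Each is a butterfly reduction: five __shfl_xor_sync steps (mask = 16, 8, 4, 2, 1) exchange values between lanes whose indices differ in a single bit, so after the loop every lane of the 32-wide warp holds the full sum or maximum. A minimal sketch of the idiom, assuming one warp per row; it is not part of the patch and row_sum_warp is a hypothetical kernel name:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// butterfly reduction: after 5 XOR-shuffle steps every lane holds the warp sum
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32); // lane i exchanges with lane i^mask
    }
    return x;
}

// one warp (32 threads) strides over a row of n floats and writes its sum
static __global__ void row_sum_warp(const float * x, float * out, const int n) {
    float sum = 0.0f;
    for (int i = threadIdx.x; i < n; i += 32) {
        sum += x[i];
    }
    sum = warp_reduce_sum(sum);
    if (threadIdx.x == 0) {
        *out = sum;
    }
}

int main() {
    const int n = 1000;
    float * x; float * out;
    cudaMallocManaged(&x,   n*sizeof(float));
    cudaMallocManaged(&out,   sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = 1.0f;
    row_sum_warp<<<1, 32>>>(x, out, n);
    cudaDeviceSynchronize();
    printf("sum = %f\n", *out); // expect 1000.0
    cudaFree(x); cudaFree(out);
    return 0;
}
```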
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
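The rewritten kernel fuses the scale factor and the optional additive mask y into the softmax, computing softmax(x[rowx]*scale + y[rowy]) per row, with the mask broadcast across rows through rowy = rowx % nrows_y (y may be null). Because a block can now have up to CUDA_SOFT_MAX_BLOCK_SIZE = 1024 threads instead of a single warp, both the max and the sum are reduced in two stages: a shuffle reduction inside each warp, then the per-warp partials are combined through the shared buf array. A minimal sketch of that two-stage pattern, not part of the patch; the kernel name is hypothetical and the final combine is done serially by thread 0 for brevity, whereas the patch has every warp re-reduce the shared partials:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE  32
#define BLOCK_SIZE 256

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

// block-wide max of a row of n floats: warp shuffles + shared per-warp partials
static __global__ void row_max_block(const float * x, float * out, const int n) {
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    __shared__ float buf[BLOCK_SIZE/WARP_SIZE]; // one partial per warp

    float v = -INFINITY;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        v = fmaxf(v, x[i]);
    }

    v = warp_reduce_max(v);   // stage 1: reduce within each warp
    if (lane_id == 0) {
        buf[warp_id] = v;     // stage 2: publish the per-warp partials
    }
    __syncthreads();

    if (threadIdx.x == 0) {   // combine the partials (serial here for brevity)
        float m = buf[0];
        for (int w = 1; w < BLOCK_SIZE/WARP_SIZE; ++w) {
            m = fmaxf(m, buf[w]);
        }
        *out = m;
    }
}

int main() {
    const int n = 4096;
    float * x; float * out;
    cudaMallocManaged(&x,   n*sizeof(float));
    cudaMallocManaged(&out,   sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = (float) i;
    row_max_block<<<1, BLOCK_SIZE>>>(x, out, n);
    cudaDeviceSynchronize();
    printf("max = %f\n", *out); // expect 4095.0
    cudaFree(x); cudaFree(out);
    return 0;
}
```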
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
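The launcher now assigns one block per row and sizes the block as the smallest power of two that covers ncols_x, clamped between WARP_SIZE and CUDA_SOFT_MAX_BLOCK_SIZE; rows longer than 1024 elements are handled by the strided column loops inside the kernel. A minimal sketch of the rounding rule, not part of the patch; soft_max_block_size is a hypothetical helper name:

```cuda
#include <cstdio>

#define WARP_SIZE                32
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024

// smallest power of two >= ncols_x, clamped to [WARP_SIZE, CUDA_SOFT_MAX_BLOCK_SIZE]
static int soft_max_block_size(int ncols_x) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    return nth;
}

int main() {
    printf("%d %d %d %d\n",
        soft_max_block_size(10),      // 32   (never below a warp)
        soft_max_block_size(100),     // 128
        soft_max_block_size(1000),    // 1024
        soft_max_block_size(100000)); // 1024 (capped; the kernel loops over columns)
    return 0;
}
```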
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(