@@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
-        float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float  * __restrict__ dst,
+        float2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16(
     if (parallel_blocks == 1 || tid != 0) {
         return;
     }
-    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum);
+    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
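
Aside, not part of the diff: both kernels write one metadata entry per output row and per parallel block, where `.x` carries the row's running KQ maximum and `.y` the partial softmax denominator accumulated by that block. A minimal host-side sketch of the flattened layout behind the `dst_meta[...]` store above, with hypothetical helper and parameter names (`n_blocks_y` stands in for `gridDim.y`):

```cuda
// Sketch only (not from the commit): flattened index of the {kqmax, kqsum} entry
// written for output row `row`, y-grid slice `block_y` and parallel block `ip`.
// Mirrors: ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip
#include <cstdio>

static int fattn_meta_index(int row, int n_blocks_y, int parallel_blocks, int block_y, int ip) {
    return row*n_blocks_y*parallel_blocks + block_y*parallel_blocks + ip;
}

int main() {
    // Example: row 3, 8 blocks in y, 4 parallel blocks, y-block 2, parallel block 1.
    printf("meta index: %d\n", fattn_meta_index(3, 8, 4, 2, 1)); // 3*8*4 + 2*4 + 1 = 105
    return 0;
}
```
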
@@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
-        float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float  * __restrict__ dst,
+        float2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16(
             continue;
         }

-        half2 dst_meta_val;
+        float2 dst_meta_val;
         if (std::is_same<KQ_acc_t, float>::value) {
-            reinterpret_cast<half&>(dst_meta_val.x) = KQ_max_f[j0/nwarps];
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
         } else {
-            dst_meta_val = KQ_max_h2[j0/nwarps];
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
-        reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
+        dst_meta_val.y = KQ_rowsum_j;
         dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
     }
 #else
@@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16(
 template <int D, int parallel_blocks> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_combine_results(
-        const float * __restrict__ VKQ_parts,
-        const half2 * __restrict__ VKQ_meta,
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst) {
 #if FP16_AVAILABLE
     VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
@@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results(
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

-    __shared__ half2 meta[parallel_blocks];
-    if (tid < parallel_blocks) {
-        meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
     }

     __syncthreads();

-    half kqmax = __low2half(meta[0]);
+    float kqmax = meta[0].x;
 #pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
-        kqmax = __hmax(kqmax, __low2half(meta[l]));
+        kqmax = max(kqmax, meta[l].x);
     }

     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
 #pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
-        const half diff = __low2half(meta[l]) - kqmax;
-        float KQ_max_scale = hexp(diff);
-        const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD));
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint *) &KQ_max_scale) &= ftz_mask;

         VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
-        VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
+        VKQ_denominator += KQ_max_scale * meta[l].y;
     }

     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
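
For readers following the combination step: the kernel above merges the `parallel_blocks` partial softmax results per output row in the usual log-sum-exp way, i.e. take the global maximum of the per-block maxima, rescale each block's partial numerator and denominator by `expf(kqmax_l - kqmax)`, and divide at the end. Below is a host-side reference sketch of that math, not the commit's code: the flush-to-zero threshold value is an assumption, and the bit-mask trick is replaced by a plain branch.

```cuda
// Reference sketch (host code): the merge done by flash_attn_combine_results for a
// single output element. parts[l] stands in for VKQ_parts at fixed (blockIdx.y, tid);
// meta[l] = {kqmax, kqsum} of parallel block l.
#include <cuda_runtime.h> // float2, make_float2
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static const float SOFTMAX_FTZ_THRESHOLD_SKETCH = -20.0f; // assumed flush-to-zero cutoff

static float combine_one_element(const std::vector<float> & parts, const std::vector<float2> & meta) {
    // 1. Global maximum over the per-block maxima (numerical stability).
    float kqmax = meta[0].x;
    for (size_t l = 1; l < meta.size(); ++l) {
        kqmax = std::max(kqmax, meta[l].x);
    }

    // 2. Rescale each block's partial numerator and denominator to the global maximum.
    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (size_t l = 0; l < meta.size(); ++l) {
        const float diff  = meta[l].x - kqmax;
        float       scale = std::exp(diff);
        if (diff <= SOFTMAX_FTZ_THRESHOLD_SKETCH) {
            scale = 0.0f; // same effect as the ftz_mask bit trick in the kernel
        }
        numerator   += scale * parts[l];
        denominator += scale * meta[l].y;
    }
    return numerator / denominator;
}

int main() {
    // Two parallel blocks that processed disjoint halves of the KV sequence.
    const std::vector<float>  parts = {0.8f, 1.4f};
    const std::vector<float2> meta  = {make_float2(2.0f, 3.0f), make_float2(5.0f, 1.5f)};
    printf("combined: %f\n", combine_one_element(parts, meta));
    return 0;
}
```
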
@@ -643,8 +643,8 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
-    ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -694,8 +694,8 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
-    ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
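
Aside, not part of the diff: with `parallel_blocks > 1` the launchers allocate `parallel_blocks` full-sized partial outputs for `dst_tmp` (shown above) plus the `float2` metadata that `flash_attn_combine_results` later reduces. The metadata allocation itself falls outside these hunks; the sketch below only illustrates a sizing consistent with the one-entry-per-row-per-parallel-block indexing used by the kernels, with hypothetical helper names.

```cuda
#include <cstddef>
#include <cstdio>

// Sketch only (assumption, not the launcher's actual code): element counts for the two
// scratch buffers when the KV sequence is split across `parallel_blocks` blocks.
static size_t dst_tmp_elems(size_t kqv_elements, size_t parallel_blocks) {
    return parallel_blocks*kqv_elements; // one float per partial output element
}
static size_t dst_tmp_meta_elems(size_t n_rows, size_t parallel_blocks) {
    return parallel_blocks*n_rows;       // one float2 {kqmax, kqsum} per output row per block
}

int main() {
    // Example: 32 heads * 7 tokens = 224 output rows of head size 128 -> 28672 elements, 4 blocks.
    printf("dst_tmp: %zu floats, dst_tmp_meta: %zu float2 entries\n",
           dst_tmp_elems(28672, 4), dst_tmp_meta_elems(224, 4));
    return 0;
}
```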