@@ -1842,16 +1842,27 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+    const cuda_t * src0_ptr = nullptr;
+    const cuda_t * src1_ptr = nullptr;
+
+    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+    // Handle src0
+    src0_ptr = (const cuda_t *) src0->data;
+
+    // Handle src1 - convert if necessary
+    if (src1->type == src0_type) {
+        src1_ptr = (const cuda_t *) src1->data;
+    } else {
+        // Convert src1 to target type using traits conversion functions
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_alloc.alloc(ne_src1);
 
-        src1_f16 = src1_f16_alloc.get();
+        const auto convert_func = traits::get_nc_converter(src1->type);
+        GGML_ASSERT(convert_func != nullptr);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        src1_ptr = src1_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
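Note: `cuda_t`, `src0_type`, and `traits` in the hunk above come from a per-type traits helper that sits outside this excerpt. Purely as a reading aid, here is a minimal sketch of what such a traits specialization could look like; the struct name `batched_mul_mat_traits` and its members are assumptions, not necessarily the symbols this change actually introduces.

```cpp
// Sketch only (assumed names): maps a src0 ggml type to its device storage type,
// its cuBLAS data type, and a getter for the non-contiguous src1 converter.
template <ggml_type src0_type> struct batched_mul_mat_traits;

template <> struct batched_mul_mat_traits<GGML_TYPE_F16> {
    using cuda_t = half;                                            // device element type
    static constexpr ggml_type      ggml_type_val = GGML_TYPE_F16;  // used when converting dst back to F32
    static constexpr cudaDataType_t data_type     = CUDA_R_16F;     // cuBLAS A/B data type
    // Returns a kernel that converts a possibly non-contiguous src1 to fp16,
    // reusing the existing ggml_get_to_fp16_nc_cuda() helper.
    static to_fp16_nc_cuda_t get_nc_converter(ggml_type src1_type) {
        return ggml_get_to_fp16_nc_cuda(src1_type);
    }
};

// BF16 and F32 specializations would follow the same pattern, mapping to
// nv_bfloat16 / CUDA_R_16BF and float / CUDA_R_32F respectively.
```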
@@ -1948,11 +1959,29 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    // Convert output back to F32 if needed
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
     }
 }
 
@@ -2004,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    // TODO update for generic tensor parallelism
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
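The hunk above only defines the `use_batched_cublas_*` flags; the branch that actually calls the new `ggml_cuda_mul_mat_batched_cublas()` dispatcher lies further down in `ggml_cuda_mul_mat()` and is not part of this excerpt. A hedged sketch of how such a condition could combine the flags (the exact predicate in the patch may differ):

```cpp
// Sketch only: one plausible way the flags could gate the batched cuBLAS path.
// The real condition likely also checks transposition and the batch dimensions.
if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
    && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)
    && src1->ne[2]*src1->ne[3] > 1) {
    // general batched matrix multiplication for multi-sequence / multi-head inputs
    ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
}
```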