@@ -488,6 +488,34 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
488488 }
489489}
490490
491+ static cudaError_t ggml_cuda_h2d_tensor_2d_hack (void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream, void * wdata) {
492+ const uint64_t ne0 = src->ne [0 ];
493+ const uint64_t ne1 = src->ne [1 ];
494+ const uint64_t nb0 = src->nb [0 ];
495+ const uint64_t nb1 = src->nb [1 ];
496+ const uint64_t nb2 = src->nb [2 ];
497+ const uint64_t nb3 = src->nb [3 ];
498+ const enum ggml_type type = src->type ;
499+ const size_t ts = ggml_type_size (type);
500+ const size_t bs = ggml_blck_size (type);
501+
502+ const void * x = (const void *) ((const char *) wdata + i2*nb2 + i3*nb3);
503+ if (nb0 == ts && nb1 == ts*ne0/bs) {
504+ return cudaMemcpyAsync (dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
505+ } else if (nb0 == ts) {
506+ return cudaMemcpy2DAsync (dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
507+ } else {
508+ for (uint64_t i1 = 0 ; i1 < ne1; i1++) {
509+ const void * rx = (const void *) ((const char *) x + i1*nb1);
510+ void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
511+ // pretend the row is a matrix with cols=1
512+ cudaError_t r = cudaMemcpy2DAsync (rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
513+ if (r != cudaSuccess) return r;
514+ }
515+ return cudaSuccess;
516+ }
517+ }
518+
491519static void ggml_cuda_mul_mat_f32 (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
492520 const int64_t ne00 = src0->ne [0 ];
493521 const int64_t ne01 = src0->ne [1 ];
@@ -695,13 +723,13 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
695723 CUDA_CHECK (cudaEventRecord (cudaEvent, cudaStream2));
696724
697725 // copy src1 to device
698- CUDA_CHECK (ggml_cuda_h2d_tensor_2d (c_Y, src1, i03, i02, cudaStream));
726+ CUDA_CHECK (ggml_cuda_h2d_tensor_2d_hack (c_Y, src1, i03, i02, cudaStream, wdata ));
699727
700728 // wait for data
701729 CUDA_CHECK (cudaStreamWaitEvent (cudaStream, cudaEvent, 0 ));
702730
703731 // compute
704- dequantize_mul_mat_q4_0_cuda (c_Q, wdata + i * QK8_0 , c_D, ne00, ne01, cudaStream);
732+ dequantize_mul_mat_q4_0_cuda (c_Q, c_Y , c_D, ne00, ne01, cudaStream);
705733 CUDA_CHECK (cudaGetLastError ());
706734
707735 } else {
0 commit comments