@@ -81,6 +81,50 @@ struct CublasDgemmOp {
8181 }
8282};
8383
84+ struct CublasSgemmBatchOp {
85+ typedef float TDatatype;
86+ cublasHandle_t handle;
87+ explicit CublasSgemmBatchOp (cublasHandle_t hdl)
88+ : handle(hdl)
89+ {}
90+ void operator ()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float * A,
91+ int a_stride, int lda, float * B, int b_stride, int ldb, float beta, float * C,
92+ int c_stride, int ldc) {
93+ CHECK_CUBLAS_ERROR (cublasSgemmStridedBatched (handle,
94+ BooleanToTranspose (ta),
95+ BooleanToTranspose (tb),
96+ M, N, K,
97+ &alpha,
98+ A, lda, a_stride,
99+ B, ldb, b_stride,
100+ &beta,
101+ C, ldc, c_stride,
102+ batch_size));
103+ }
104+ };
105+
106+ struct CublasDgemmBatchOp {
107+ typedef double TDatatype;
108+ cublasHandle_t handle;
109+ explicit CublasDgemmBatchOp (cublasHandle_t hdl)
110+ : handle(hdl)
111+ {}
112+ void operator ()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double * A,
113+ int a_stride, int lda, double * B, int b_stride, int ldb, double beta, double * C,
114+ int c_stride, int ldc) {
115+ CHECK_CUBLAS_ERROR (cublasDgemmStridedBatched (handle,
116+ BooleanToTranspose (ta),
117+ BooleanToTranspose (tb),
118+ M, N, K,
119+ &alpha,
120+ A, lda, a_stride,
121+ B, ldb, b_stride,
122+ &beta,
123+ C, ldc, c_stride,
124+ batch_size));
125+ }
126+ };
127+
84128// matrix multiplication for row major
85129TVM_REGISTER_GLOBAL (" tvm.contrib.cublas.matmul" )
86130.set_body([](TVMArgs args, TVMRetValue *ret) {
@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
96140 else
97141 CallGemm (args, ret, CublasDgemmOp (entry_ptr->handle ));
98142});
143+
144+ TVM_REGISTER_GLOBAL (" tvm.contrib.cublas.batch_matmul" )
145+ .set_body([](TVMArgs args, TVMRetValue* ret) {
146+ DLTensor* A = args[0 ];
147+
148+ CHECK (TypeMatch (A->dtype , kDLFloat , 32 ) ||
149+ TypeMatch (A->dtype , kDLFloat , 64 ));
150+
151+ CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal ();
152+
153+ if (TypeMatch (A->dtype , kDLFloat , 32 ))
154+ CallBatchGemm (args, ret, CublasSgemmBatchOp (entry_ptr->handle ));
155+ else
156+ CallBatchGemm (args, ret, CublasDgemmBatchOp (entry_ptr->handle ));
157+ });
158+
99159} // namespace contrib
100160} // namespace tvm
0 commit comments