Commit 88f9bfd

soiferj authored and icemelon committed
[TOPI][CUDA] Support cuBLAS BatchMatMul (#3936)
* Support cuBLAS BatchMatMul
* Add test and check target name
1 parent 1de52bb commit 88f9bfd

File tree

5 files changed (+176, -1 lines changed)

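Taken together, these changes route `topi.nn.batch_matmul` to cuBLAS whenever the CUDA target lists `cublas` among its libraries. A minimal sketch of exercising the new path (assuming a TVM build with cuBLAS enabled and a CUDA GPU available; the shapes here are arbitrary):

```python
import tvm
import topi

# x: [batch, M, K], y: [batch, N, K] -- topi.nn.batch_matmul takes the
# second operand pre-transposed, which is why the dispatch below passes
# transb=True to cuBLAS.
A = tvm.placeholder((16, 1024, 128), name='A')
B = tvm.placeholder((16, 235, 128), name='B')

with tvm.target.create("cuda -libs=cublas"):
    C = topi.nn.batch_matmul(A, B)               # dispatches to cuBLAS
    s = topi.generic.schedule_batch_matmul([C])  # extern schedule

f = tvm.build(s, [A, B, C], "cuda")
```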

python/tvm/contrib/cublas.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -46,3 +46,31 @@ def matmul(lhs, rhs, transa=False, transb=False):
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.cublas.matmul",
             ins[0], ins[1], outs[0], transa, transb), name="C")
+
+def batch_matmul(lhs, rhs, transa=False, transb=False):
+    """Create an extern op that computes batch matrix multiplication of lhs and rhs with cuBLAS
+
+    Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether to transpose lhs
+    transb : bool
+        Whether to transpose rhs
+
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    b = lhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    return _api.extern(
+        (b, n, m), [lhs, rhs],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.cublas.batch_matmul",
+            ins[0], ins[1], outs[0], transa, transb), name="C")
```
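The shape arithmetic above follows BLAS transpose conventions: with `transa`/`transb` set, the reduction axis moves. As a reference for what the extern op is expected to compute, here is a NumPy sketch (`batch_matmul_ref` is a hypothetical helper, not part of this change):

```python
import numpy as np

def batch_matmul_ref(lhs, rhs, transa=False, transb=False):
    # Per batch i: op(lhs[i]) @ op(rhs[i]), where op() optionally transposes.
    a = lhs.transpose(0, 2, 1) if transa else lhs  # -> [b, n, k]
    b = rhs.transpose(0, 2, 1) if transb else rhs  # -> [b, k, m]
    return np.matmul(a, b)                         # -> [b, n, m]
```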

src/contrib/cublas/cublas.cc

Lines changed: 60 additions & 0 deletions
```diff
@@ -81,6 +81,50 @@ struct CublasDgemmOp {
   }
 };
 
+struct CublasSgemmBatchOp {
+  typedef float TDatatype;
+  cublasHandle_t handle;
+  explicit CublasSgemmBatchOp(cublasHandle_t hdl)
+    : handle(hdl)
+  {}
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+                  int c_stride, int ldc) {
+    CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
+                                                 BooleanToTranspose(ta),
+                                                 BooleanToTranspose(tb),
+                                                 M, N, K,
+                                                 &alpha,
+                                                 A, lda, a_stride,
+                                                 B, ldb, b_stride,
+                                                 &beta,
+                                                 C, ldc, c_stride,
+                                                 batch_size));
+  }
+};
+
+struct CublasDgemmBatchOp {
+  typedef double TDatatype;
+  cublasHandle_t handle;
+  explicit CublasDgemmBatchOp(cublasHandle_t hdl)
+    : handle(hdl)
+  {}
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                  int c_stride, int ldc) {
+    CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
+                                                 BooleanToTranspose(ta),
+                                                 BooleanToTranspose(tb),
+                                                 M, N, K,
+                                                 &alpha,
+                                                 A, lda, a_stride,
+                                                 B, ldb, b_stride,
+                                                 &beta,
+                                                 C, ldc, c_stride,
+                                                 batch_size));
+  }
+};
+
 // matrix multiplication for row major
 TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
   else
     CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
 });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* A = args[0];
+
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
+        TypeMatch(A->dtype, kDLFloat, 64));
+
+  CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+
+  if (TypeMatch(A->dtype, kDLFloat, 32))
+    CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
+  else
+    CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
+});
+
 }  // namespace contrib
 }  // namespace tvm
```
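Both functors wrap cuBLAS's strided-batched GEMM, which runs one GEMM per batch element and advances each operand pointer by a fixed element stride between batches. In NumPy terms the computation is roughly the following (a sketch of the semantics, not the CUDA API; it assumes contiguous batches, i.e. each stride equals one full matrix):

```python
import numpy as np

def gemm_strided_batched_ref(alpha, A, B, beta, C):
    # A: [batch, M, K], B: [batch, K, N], C: [batch, M, N]
    for i in range(A.shape[0]):
        C[i] = alpha * (A[i] @ B[i]) + beta * C[i]
    return C
```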

tests/python/contrib/test_cublas.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -44,6 +44,34 @@ def verify(target="cuda"):
         c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
     verify()
 
+def test_batch_matmul():
+    j = 16
+    n = 1024
+    l = 128
+    m = 235
+    A = tvm.placeholder((j, n, l), name='A')
+    B = tvm.placeholder((j, l, m), name='B')
+    C = cublas.batch_matmul(A, B)
+    s = tvm.create_schedule(C.op)
+
+    def verify(target="cuda"):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.cublas.batch_matmul", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.gpu(0)
+        f = tvm.build(s, [A, B, C], target)
+        a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
+        f(a, b, c)
+        tvm.testing.assert_allclose(
+            c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+    verify()
+
 
 if __name__ == "__main__":
     test_matmul_add()
+    test_batch_matmul()
```

topi/include/topi/contrib/cublas.h

Lines changed: 32 additions & 0 deletions
```diff
@@ -61,6 +61,38 @@ inline Tensor cublas_matmul(const Tensor& lhs,
     }, "C", "", {})[0];
 }
 
+/*!
+ * \brief Create an op that multiplies batch matrices
+ *        lhs and rhs with cuBLAS
+ *
+ * \param lhs The left matrix operand
+ * \param rhs The right matrix operand
+ * \param transa Whether to transpose lhs
+ * \param transb Whether to transpose rhs
+ *
+ * \return The output tensor
+ */
+inline Tensor cublas_batch_matmul(const Tensor& lhs,
+                                  const Tensor& rhs,
+                                  bool transa,
+                                  bool transb) {
+  auto b = lhs->shape[0];
+  auto n = transa ? lhs->shape[2] : lhs->shape[1];
+  auto m = transb ? rhs->shape[1] : rhs->shape[2];
+
+  return make_extern(
+    { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
+    [&](Array<Buffer> ins, Array<Buffer> outs) {
+      return call_packed({
+        Expr("tvm.contrib.cublas.batch_matmul"),
+        pack_buffer(ins[0]),
+        pack_buffer(ins[1]),
+        pack_buffer(outs[0]),
+        transa,
+        transb });
+    }, "C", "", {})[0];
+}
+
 }  // namespace contrib
 }  // namespace topi
```

topi/python/topi/cuda/batch_matmul.py

Lines changed: 28 additions & 1 deletion
```diff
@@ -18,10 +18,33 @@
 """cuda batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
-
+from tvm.contrib import cublas
+from topi.nn import batch_matmul, batch_matmul_default
 from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
+@batch_matmul.register(["cuda", "gpu"])
+def batch_matmul_cuda(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y`
+    are batched matrices.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        return cublas.batch_matmul(x, y, False, True)
+    return batch_matmul_default(x, y)
 
 @generic.schedule_batch_matmul.register(["cuda", "gpu"])
 def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
     s: Schedule
         The computation schedule for the op.
     """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        return generic.schedule_extern(outs)
+
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
```
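The dispatch condition is the same in both functions above: the current target must be named cuda and must list cublas among its libraries. A quick sketch of how that condition reads off a target string (using the `tvm.target` API of this era):

```python
import tvm

tgt = tvm.target.create("cuda -libs=cublas")
print(tgt.target_name)         # "cuda"
print("cublas" in tgt.libs)    # True -> the cuBLAS extern path is taken

plain = tvm.target.create("cuda")
print("cublas" in plain.libs)  # False -> falls back to batch_matmul_default
```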