
Commit c870261

soiferj authored and hlu1 committed
[TOPI] Use cblas for dense and batch_matmul when "cblas" is in the target libraries (#3787)
* Support cblas library in dense
* Start to add support for generic batch_matmul compute
* Add x86 override for batch_matmul
* Fix linting
* Reset file
* Fix typos
* Dummy change to re-trigger CI
1 parent 95f12e3 commit c870261
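
For context on how the new code paths get selected: the cblas branch only fires when the string "cblas" appears in the target's library list. A minimal sketch of enabling it from Relay, using the API of this TVM era (the tiny dense workload and its shapes are illustrative placeholders, and TVM must be built against a BLAS library so the tvm.contrib.cblas externs resolve at runtime):

import tvm
from tvm import relay

# Tiny dense workload; shapes are arbitrary illustration values.
data = relay.var("data", shape=(1, 16))
weight = relay.var("weight", shape=(8, 16))
func = relay.Function([data, weight], relay.nn.dense(data, weight))

# "-libs=cblas" puts "cblas" into target.libs, which the TOPI code in this patch checks.
target = tvm.target.create("llvm -libs=cblas")
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(relay.Module.from_expr(func), target=target)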

File tree

4 files changed, +114 −50 lines

python/tvm/relay/op/nn/_nn.py

Lines changed: 2 additions & 1 deletion
@@ -73,7 +73,8 @@ def schedule_dense(attrs, outputs, target):
 @reg.register_compute("nn.batch_matmul")
 def compute_batch_matmul(attrs, inputs, out_type, target):
     """Compute definition of batch_matmul"""
-    return [topi.nn.batch_matmul(inputs[0], inputs[1])]
+    with target:
+        return [topi.nn.batch_matmul(inputs[0], inputs[1])]
 
 
 @reg.register_schedule("nn.batch_matmul")
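
The with target: wrapper is what lets the new TOPI override see the compilation target: entering the context makes tvm.target.current_target() return it, so the "cblas" check further down can work. A small illustration of that behaviour (not part of the patch):

import tvm

t = tvm.target.create("llvm -libs=cblas")
with t:
    # Inside the context, TOPI code can query the ambient target.
    assert "cblas" in tvm.target.current_target().libs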

topi/python/topi/nn/batch_matmul.py

Lines changed: 22 additions & 3 deletions
@@ -20,8 +20,7 @@
 import tvm
 from ..util import get_const_tuple
 
-
-def batch_matmul(x, y):
+def batch_matmul_default(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -30,7 +29,7 @@ def batch_matmul(x, y):
     x : tvm.Tensor
         3-D with shape [batch, M, K]
 
-    y : tvm.TEnsor
+    y : tvm.Tensor
         3-D with shape [batch, N, K]
 
     Returns
@@ -49,3 +48,23 @@ def batch_matmul(x, y):
     return tvm.compute((batch, M, N),
                        lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
                        tag='batch_matmul')
+
+
+@tvm.target.generic_func
+def batch_matmul(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    return batch_matmul_default(x, y)
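
tvm.target.generic_func turns batch_matmul into a dispatchable function: the body above is the fallback, and per-target overrides registered with .register([...]) (as the x86 file does next) take over when a matching target is current. A toy example of the same mechanism, using hypothetical names:

import tvm

@tvm.target.generic_func
def toy_op(x):
    return "generic fallback"

@toy_op.register(["cpu"])
def toy_op_cpu(x):
    return "cpu override"

# "llvm" targets carry the "cpu" key, so the override is selected here.
with tvm.target.create("llvm"):
    print(toy_op(None))  # -> "cpu override"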

topi/python/topi/x86/batch_matmul.py

Lines changed: 28 additions & 1 deletion
@@ -18,10 +18,33 @@
 """x86 batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
-
+from tvm.contrib import cblas
+from topi.nn import batch_matmul, batch_matmul_default
 from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
+@batch_matmul.register(["cpu"])
+def batch_matmul_x86(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return cblas.batch_matmul(x, y, False, True)
+    return batch_matmul_default(x, y)
 
 @generic.schedule_batch_matmul.register(["cpu"])
 def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
     sch: Schedule
         The computation schedule for the op.
     """
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
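
A note on the False, True arguments to cblas.batch_matmul: they are the transpose flags for the two operands. TOPI stores y as [batch, N, K], so transposing the second operand reproduces the generic compute, i.e. out[b, i, j] = sum_k x[b, i, k] * y[b, j, k]. A NumPy sketch of that convention (illustration only, not part of the patch):

import numpy as np

batch, M, N, K = 2, 3, 4, 5
x = np.random.rand(batch, M, K).astype("float32")
y = np.random.rand(batch, N, K).astype("float32")  # note the [batch, N, K] layout

# Equivalent to what cblas.batch_matmul(x, y, False, True) and the generic compute produce.
out = np.einsum("bik,bjk->bij", x, y)
assert out.shape == (batch, M, N)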

topi/python/topi/x86/dense.py

Lines changed: 62 additions & 45 deletions
@@ -20,6 +20,7 @@
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity
+from tvm.contrib import cblas
 
 from .util import get_fp32_len
 from .. import generic, tag, nn
@@ -40,29 +41,33 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute with packing weight into cache-friendly layout
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
 def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_y", batch, num_outputs=3)
-    cfg.define_split("tile_x", out_dim, num_outputs=3)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_pack_config(cfg, batch, out_dim, in_dim)
-
-    packw_bn = cfg["tile_x"].size[-1]
-    packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
-    packw = tvm.compute(packw_shape,
-                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
-
-    k = tvm.reduce_axis((0, in_dim), name="k")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(
-                        data[y, k].astype(out_dtype) *
-                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
-                        axis=k),
-                    tag="dense_pack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
+        out_dim, _ = get_const_tuple(weight.shape)
+        # create tuning space
+        cfg.define_split("tile_y", batch, num_outputs=3)
+        cfg.define_split("tile_x", out_dim, num_outputs=3)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_pack_config(cfg, batch, out_dim, in_dim)
+
+        packw_bn = cfg["tile_x"].size[-1]
+        packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
+        packw = tvm.compute(packw_shape,
+                            lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+
+        k = tvm.reduce_axis((0, in_dim), name="k")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(
+                            data[y, k].astype(out_dtype) *
+                            packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                            axis=k),
+                        tag="dense_pack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
@@ -72,28 +77,32 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute without packing weight
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
 def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_x", out_dim, num_outputs=2)
-    cfg.define_split("tile_y", batch, num_outputs=2)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
-
-    vec = cfg["tile_k"].size[-1]
-    k = tvm.reduce_axis((0, in_dim // vec), "k")
-    CC = tvm.compute((batch, out_dim, vec),
-                     lambda z, y, x: tvm.sum(
-                         data[z, k * vec + x].astype(out_dtype) *
-                         weight[y, k * vec + x].astype(out_dtype), axis=k))
-
-    kk = tvm.reduce_axis((0, vec), "kk")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                    tag="dense_nopack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
+        out_dim, _ = get_const_tuple(weight.shape)
+        # create tuning space
+        cfg.define_split("tile_x", out_dim, num_outputs=2)
+        cfg.define_split("tile_y", batch, num_outputs=2)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
+
+        vec = cfg["tile_k"].size[-1]
+        k = tvm.reduce_axis((0, in_dim // vec), "k")
+        CC = tvm.compute((batch, out_dim, vec),
+                         lambda z, y, x: tvm.sum(
+                             data[z, k * vec + x].astype(out_dtype) *
+                             weight[y, k * vec + x].astype(out_dtype), axis=k))
+
+        kk = tvm.reduce_axis((0, vec), "kk")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                        tag="dense_nopack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
@@ -116,6 +125,10 @@ def _callback(op):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
 def _schedule_dense_pack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
@@ -127,6 +140,10 @@ def _callback(op):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
 def _schedule_dense_nopack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
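
The dense path follows the same convention: nn.dense keeps weight as [out_dim, in_dim], so cblas.matmul(data, weight, False, True) computes data · weightᵀ, and the bias add stays a normal TVM compute on top of the extern result. A standalone sketch of the contrib call being routed to (assuming a TVM build with a BLAS library so the tvm.contrib.cblas packed functions exist):

import numpy as np
import tvm
from tvm.contrib import cblas

batch, in_dim, out_dim = 4, 16, 8
data = tvm.placeholder((batch, in_dim), name="data")
weight = tvm.placeholder((out_dim, in_dim), name="weight")  # dense layout: [out_dim, in_dim]

C = cblas.matmul(data, weight, False, True)  # C = data * weight^T
s = tvm.create_schedule(C.op)
f = tvm.build(s, [data, weight, C], target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(batch, in_dim).astype("float32"), ctx)
b = tvm.nd.array(np.random.rand(out_dim, in_dim).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((batch, out_dim), dtype="float32"), ctx)
f(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy().dot(b.asnumpy().T), rtol=1e-5)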
