Skip to content

Commit 2d8ac1d

Browse files
masahimasa
andauthored
[MKL] Fix offloading of batch_matmul to MKL (#6752)
* fix mkl offloading of batch matmul * name fix and add doc * add doc for lib arg Co-authored-by: masa <[email protected]>
1 parent 1831c17 commit 2d8ac1d

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

python/tvm/relay/op/strategy/x86.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
377377
name="batch_matmul_cblas.x86",
378378
plevel=15,
379379
)
380+
if "mkl" in target.libs:
381+
strategy.add_implementation(
382+
wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl),
383+
wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl),
384+
name="batch_matmul_mkl.x86",
385+
plevel=15,
386+
)
380387
return strategy
381388

382389

python/tvm/topi/x86/batch_matmul.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tvm import te
2020
from tvm import autotvm
2121
from tvm.autotvm.task.space import SplitEntity
22-
from tvm.contrib import cblas
22+
from tvm.contrib import cblas, mkl
2323
from .. import generic
2424
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
2525

@@ -137,10 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K):
137137
cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])
138138

139139

140-
@autotvm.register_topi_compute("batch_matmul_cblas.x86")
141-
def batch_matmul_cblas(cfg, x, y, out_shape=None):
140+
def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
142141
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
143-
data in batch.
142+
data in batch, using one of BLAS libraries.
144143
145144
Parameters
146145
----------
@@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
152151
3-D with shape [batch, N, K]
153152
out_shape : tuple or None
154153
Shape of the output
154+
lib : A contrib module which implements batch_matmul funtion
155+
cblas and mkl are supported
155156
156157
Returns
157158
-------
@@ -168,9 +169,28 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
168169
assert out_shape[1] == M, "got invalid output shape"
169170
assert out_shape[2] == N, "got invalid output shape"
170171
cfg.add_flop(XB * M * N * XK * 2)
171-
return cblas.batch_matmul(x, y, False, True)
172+
return lib.batch_matmul(x, y, False, True)
173+
174+
175+
@autotvm.register_topi_compute("batch_matmul_cblas.x86")
176+
def batch_matmul_cblas(cfg, x, y, out_shape=None):
177+
"""Compute batch_matmul using cblas"""
178+
return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)
172179

173180

174181
@autotvm.register_topi_schedule("batch_matmul_cblas.x86")
175182
def schedule_batch_matmul_cblas(_, outs):
183+
"""Create schedule for batch_matmul_cblas"""
184+
return generic.schedule_extern(outs)
185+
186+
187+
@autotvm.register_topi_compute("batch_matmul_mkl.x86")
188+
def batch_matmul_mkl(cfg, x, y, out_shape=None):
189+
"""Compute batch_matmul using mkl"""
190+
return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)
191+
192+
193+
@autotvm.register_topi_schedule("batch_matmul_mkl.x86")
194+
def schedule_batch_matmul_mkl(_, outs):
195+
"""Create schedule for batch_matmul_mul"""
176196
return generic.schedule_extern(outs)

0 commit comments

Comments
 (0)