1919from tvm import te
2020from tvm import autotvm
2121from tvm .autotvm .task .space import SplitEntity
22- from tvm .contrib import cblas
22+ from tvm .contrib import cblas , mkl
2323from .. import generic
2424from ..util import traverse_inline , get_const_tuple , get_max_power2_factor
2525
@@ -137,10 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K):
137137 cfg ["tile_y" ] = SplitEntity ([M // y_bn , y_bn ])
138138
139139
140- @autotvm .register_topi_compute ("batch_matmul_cblas.x86" )
141- def batch_matmul_cblas (cfg , x , y , out_shape = None ):
140+ def batch_matmul_blas_common (cfg , x , y , out_shape , lib ):
142141 """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
143- data in batch.
142+ data in batch, using one of BLAS libraries .
144143
145144 Parameters
146145 ----------
@@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
152151 3-D with shape [batch, N, K]
153152 out_shape : tuple or None
154153 Shape of the output
154+ lib : A contrib module which implements batch_matmul funtion
155+ cblas and mkl are supported
155156
156157 Returns
157158 -------
@@ -168,9 +169,28 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
168169 assert out_shape [1 ] == M , "got invalid output shape"
169170 assert out_shape [2 ] == N , "got invalid output shape"
170171 cfg .add_flop (XB * M * N * XK * 2 )
171- return cblas .batch_matmul (x , y , False , True )
172+ return lib .batch_matmul (x , y , False , True )
173+
174+
175+ @autotvm .register_topi_compute ("batch_matmul_cblas.x86" )
176+ def batch_matmul_cblas (cfg , x , y , out_shape = None ):
177+ """Compute batch_matmul using cblas"""
178+ return batch_matmul_blas_common (cfg , x , y , out_shape , cblas )
172179
173180
174181@autotvm .register_topi_schedule ("batch_matmul_cblas.x86" )
175182def schedule_batch_matmul_cblas (_ , outs ):
183+ """Create schedule for batch_matmul_cblas"""
184+ return generic .schedule_extern (outs )
185+
186+
187+ @autotvm .register_topi_compute ("batch_matmul_mkl.x86" )
188+ def batch_matmul_mkl (cfg , x , y , out_shape = None ):
189+ """Compute batch_matmul using mkl"""
190+ return batch_matmul_blas_common (cfg , x , y , out_shape , mkl )
191+
192+
193+ @autotvm .register_topi_schedule ("batch_matmul_mkl.x86" )
194+ def schedule_batch_matmul_mkl (_ , outs ):
195+ """Create schedule for batch_matmul_mul"""
176196 return generic .schedule_extern (outs )
0 commit comments