 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity
+from tvm.contrib import cblas

 from .util import get_fp32_len
 from .. import generic, tag, nn
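The new import pulls in tvm.contrib.cblas, which wraps a BLAS gemm call behind tvm.extern. A minimal standalone sketch of what that wrapper does (shapes are illustrative; TVM must be built with the cmake USE_BLAS option pointing at an actual BLAS implementation):

import tvm
from tvm.contrib import cblas

A = tvm.placeholder((16, 512), name="A", dtype="float32")
B = tvm.placeholder((128, 512), name="B", dtype="float32")
# transb=True computes C[i, j] = sum_k A[i, k] * B[j, k], i.e. A x B^T,
# matching dense's (out_dim, in_dim) weight convention.
C = cblas.matmul(A, B, transa=False, transb=True)
s = tvm.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")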
@@ -40,29 +41,33 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute with packing weight into cache-friendly layout
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
 def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     if out_dtype is None:
         out_dtype = data.dtype
     batch, in_dim = get_const_tuple(data.shape)
     out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_y", batch, num_outputs=3)
-    cfg.define_split("tile_x", out_dim, num_outputs=3)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_pack_config(cfg, batch, out_dim, in_dim)
-
-    packw_bn = cfg["tile_x"].size[-1]
-    packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
-    packw = tvm.compute(packw_shape,
-                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
-
-    k = tvm.reduce_axis((0, in_dim), name="k")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(
-                        data[y, k].astype(out_dtype) *
-                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
-                        axis=k),
-                    tag="dense_pack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        # create tuning space
+        cfg.define_split("tile_y", batch, num_outputs=3)
+        cfg.define_split("tile_x", out_dim, num_outputs=3)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_pack_config(cfg, batch, out_dim, in_dim)
+
+        packw_bn = cfg["tile_x"].size[-1]
+        packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
+        packw = tvm.compute(packw_shape,
+                            lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+
+        k = tvm.reduce_axis((0, in_dim), name="k")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(
+                            data[y, k].astype(out_dtype) *
+                            packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                            axis=k),
+                        tag="dense_pack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
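For reference, packed_weight in the fallback branch only re-tiles the out_dim axis by the innermost tile_x factor. A numpy sketch of the same index math, with sizes picked arbitrarily for illustration:

import numpy as np

out_dim, in_dim, bn = 128, 512, 16   # bn stands in for packw_bn
weight_np = np.random.rand(out_dim, in_dim).astype("float32")
# packw[z, y, x] = weight[z * bn + x, y]
packw_np = weight_np.reshape(out_dim // bn, bn, in_dim).transpose(0, 2, 1)
x, k = 37, 200
assert packw_np[x // bn, k, x % bn] == weight_np[x, k]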
@@ -72,28 +77,32 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute without packing weight
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
 def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
     if out_dtype is None:
         out_dtype = data.dtype
     batch, in_dim = get_const_tuple(data.shape)
     out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_x", out_dim, num_outputs=2)
-    cfg.define_split("tile_y", batch, num_outputs=2)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
-
-    vec = cfg["tile_k"].size[-1]
-    k = tvm.reduce_axis((0, in_dim // vec), "k")
-    CC = tvm.compute((batch, out_dim, vec),
-                     lambda z, y, x: tvm.sum(
-                         data[z, k * vec + x].astype(out_dtype) *
-                         weight[y, k * vec + x].astype(out_dtype), axis=k))
-
-    kk = tvm.reduce_axis((0, vec), "kk")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                    tag="dense_nopack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        # create tuning space
+        cfg.define_split("tile_x", out_dim, num_outputs=2)
+        cfg.define_split("tile_y", batch, num_outputs=2)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
+
+        vec = cfg["tile_k"].size[-1]
+        k = tvm.reduce_axis((0, in_dim // vec), "k")
+        CC = tvm.compute((batch, out_dim, vec),
+                         lambda z, y, x: tvm.sum(
+                             data[z, k * vec + x].astype(out_dtype) *
+                             weight[y, k * vec + x].astype(out_dtype), axis=k))
+
+        kk = tvm.reduce_axis((0, vec), "kk")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                        tag="dense_nopack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
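The nopack branch splits the reduction in two so the vec innermost lanes stay data-parallel until a final horizontal sum. The same arithmetic in numpy (sizes assumed), checked against the data x weight^T product that cblas.matmul(data, weight, False, True) computes:

import numpy as np

batch, out_dim, in_dim, vec = 16, 128, 512, 8
data_np = np.random.rand(batch, in_dim).astype("float32")
weight_np = np.random.rand(out_dim, in_dim).astype("float32")

CC = np.zeros((batch, out_dim, vec), "float32")
for ko in range(in_dim // vec):      # the outer "k" reduce axis
    CC += data_np[:, None, ko * vec:(ko + 1) * vec] * \
          weight_np[None, :, ko * vec:(ko + 1) * vec]
C = CC.sum(axis=2)                   # the inner "kk" reduce axis
assert np.allclose(C, data_np @ weight_np.T, rtol=1e-4)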
@@ -116,6 +125,10 @@ def _callback(op):

 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
 def _schedule_dense_pack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])

     def _callback(op):
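cblas.matmul returns the output of a tvm.extern op, which has no loop nest for AutoTVM to split or vectorize, so falling through to generic.schedule_extern is the natural schedule. A quick illustrative probe, reusing the placeholder names from the sketch above:

# A and B as defined earlier; hypothetical check, not part of the patch.
C = cblas.matmul(A, B, False, True)
print(isinstance(C.op, tvm.tensor.ExternOp))   # True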
@@ -127,6 +140,10 @@ def _callback(op):

 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
 def _schedule_dense_nopack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])

     def _callback(op):
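End to end, the library path is selected purely by the target string. A rough sketch of exercising it, modeled on the pattern in topi's dense tests (shapes assumed; requires a cblas-enabled TVM build):

import tvm
import topi

data = tvm.placeholder((16, 512), name="data")
weight = tvm.placeholder((128, 512), name="weight")

with tvm.target.create("llvm -libs=cblas"):
    C = topi.nn.dense(data, weight)        # takes the cblas branch
    s = topi.generic.schedule_dense([C])   # returns schedule_extern(outs)
f = tvm.build(s, [data, weight, C], "llvm -libs=cblas")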