@@ -20,6 +20,7 @@
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity
+from tvm.contrib import cblas
 
 from .util import get_fp32_len
 from .. import generic, tag, nn
@@ -40,29 +41,33 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute with packing weight into cache-friendly layout
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
 def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_y", batch, num_outputs=3)
-    cfg.define_split("tile_x", out_dim, num_outputs=3)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_pack_config(cfg, batch, out_dim, in_dim)
-
-    packw_bn = cfg["tile_x"].size[-1]
-    packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
-    packw = tvm.compute(packw_shape,
-                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
-
-    k = tvm.reduce_axis((0, in_dim), name="k")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(
-                        data[y, k].astype(out_dtype) *
-                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
-                        axis=k),
-                    tag="dense_pack")
+    if out_dtype is None:
+        out_dtype = data.dtype
+    batch, in_dim = get_const_tuple(data.shape)
+    out_dim, _ = get_const_tuple(weight.shape)
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        # create tuning space
+        cfg.define_split("tile_y", batch, num_outputs=3)
+        cfg.define_split("tile_x", out_dim, num_outputs=3)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_pack_config(cfg, batch, out_dim, in_dim)
+
+        packw_bn = cfg["tile_x"].size[-1]
+        packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
+        packw = tvm.compute(packw_shape,
+                            lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+
+        k = tvm.reduce_axis((0, in_dim), name="k")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(
+                            data[y, k].astype(out_dtype) *
+                            packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                            axis=k),
+                        tag="dense_pack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
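For intuition, the packed_weight layout built above can be mirrored in NumPy. This is an illustration only, not part of the patch; it follows the compute rule packw[z, y, x] = weight[z * packw_bn + x, y], which moves each block of packw_bn consecutive output rows onto the innermost, contiguous axis:

    # NumPy analogue of the packed layout (illustrative assumption)
    import numpy as np

    out_dim, in_dim, packw_bn = 8, 5, 4
    weight = np.arange(out_dim * in_dim, dtype="float32").reshape(out_dim, in_dim)
    # fold blocks of packw_bn output rows into the innermost axis
    packw = weight.reshape(out_dim // packw_bn, packw_bn, in_dim).transpose(0, 2, 1)
    # matches the compute rule above: packw[z, y, x] == weight[z * packw_bn + x, y]
    assert packw[1, 3, 2] == weight[1 * packw_bn + 2, 3]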
@@ -72,28 +77,32 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute without packing weight
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
 def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_x", out_dim, num_outputs=2)
-    cfg.define_split("tile_y", batch, num_outputs=2)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
-
-    vec = cfg["tile_k"].size[-1]
-    k = tvm.reduce_axis((0, in_dim // vec), "k")
-    CC = tvm.compute((batch, out_dim, vec),
-                     lambda z, y, x: tvm.sum(
-                         data[z, k * vec + x].astype(out_dtype) *
-                         weight[y, k * vec + x].astype(out_dtype), axis=k))
-
-    kk = tvm.reduce_axis((0, vec), "kk")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                    tag="dense_nopack")
+    if out_dtype is None:
+        out_dtype = data.dtype
+    batch, in_dim = get_const_tuple(data.shape)
+    out_dim, _ = get_const_tuple(weight.shape)
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        # create tuning space
+        cfg.define_split("tile_x", out_dim, num_outputs=2)
+        cfg.define_split("tile_y", batch, num_outputs=2)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
+
+        vec = cfg["tile_k"].size[-1]
+        k = tvm.reduce_axis((0, in_dim // vec), "k")
+        CC = tvm.compute((batch, out_dim, vec),
+                         lambda z, y, x: tvm.sum(
+                             data[z, k * vec + x].astype(out_dtype) *
+                             weight[y, k * vec + x].astype(out_dtype), axis=k))
+
+        kk = tvm.reduce_axis((0, vec), "kk")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                        tag="dense_nopack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
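The nopack variant splits the reduction into two stages so the innermost vec lanes stay vectorizable: CC holds per-lane partial sums over the outer k chunks, and C then reduces the lanes. A small NumPy analogue of the same two-stage reduction (illustrative only, not from this patch):

    import numpy as np

    batch, out_dim, in_dim, vec = 2, 3, 8, 4
    data = np.random.rand(batch, in_dim).astype("float32")
    weight = np.random.rand(out_dim, in_dim).astype("float32")

    prod = np.einsum("zk,yk->zyk", data, weight)  # elementwise products over k
    CC = prod.reshape(batch, out_dim, in_dim // vec, vec).sum(axis=2)  # reduce outer k chunks
    C = CC.sum(axis=2)  # reduce the vec lanes
    assert np.allclose(C, data @ weight.T, rtol=1e-5)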
@@ -116,6 +125,10 @@ def _callback(op):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
 def _schedule_dense_pack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
@@ -127,6 +140,10 @@ def _callback(op):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
 def _schedule_dense_nopack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
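With these changes, the cblas branch is taken whenever the target string carries -libs=cblas: the compute lowers to an extern cblas.matmul call and both schedules return generic.schedule_extern. A minimal usage sketch, assuming a TVM build of this vintage with the cblas contrib library linked in (shapes here are arbitrary examples):

    import tvm
    import topi

    data = tvm.placeholder((16, 64), name="data")
    weight = tvm.placeholder((32, 64), name="weight")
    with tvm.target.create("llvm -libs=cblas"):
        C = topi.nn.dense(data, weight)       # dispatches into the cblas branch above
        s = topi.generic.schedule_dense([C])  # falls through to schedule_extern
        func = tvm.build(s, [data, weight, C], target="llvm -libs=cblas")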