@@ -29,8 +29,7 @@ def conv1d_ncw(cfg, data, kernel, strides, padding, dilation, out_dtype="float32
     return nn.conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)
 
 
-@autotvm.register_topi_schedule("conv1d_ncw.cuda")
-def schedule_conv1d_ncw(cfg, outs):
+def _schedule_conv1d_ncw(cfg, outs):
     """TOPI schedule callback of conv1d ncw for cuda gpu
 
     Parameters
@@ -51,7 +50,7 @@ def schedule_conv1d_ncw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == "conv1d_ncw":
+        if op.tag == "conv1d_ncw" or op.tag == "group_conv1d_ncw":
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)
@@ -140,13 +139,27 @@ def _callback(op):
     return s
 
 
+@autotvm.register_topi_schedule("conv1d_ncw.cuda")
+def schedule_conv1d_ncw(cfg, outs):
+    return _schedule_conv1d_ncw(cfg, outs)
+
+
+@autotvm.register_topi_compute("group_conv1d_ncw.cuda")
+def group_conv1d_ncw(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
+    return nn.group_conv1d_ncw(data, kernel, strides, padding, dilation, groups, out_dtype)
+
+
+@autotvm.register_topi_schedule("group_conv1d_ncw.cuda")
+def schedule_group_conv1d_ncw(cfg, outs):
+    return _schedule_conv1d_ncw(cfg, outs)
+
+
 @autotvm.register_topi_compute("conv1d_nwc.cuda")
 def conv1d_nwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     return nn.conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)
 
 
-@autotvm.register_topi_schedule("conv1d_nwc.cuda")
-def schedule_conv1d_nwc(cfg, outs):
+def _schedule_conv1d_nwc(cfg, outs):
     """TOPI schedule callback of conv1d nwc for cuda gpu
 
     Parameters
@@ -167,7 +180,7 @@ def schedule_conv1d_nwc(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == "conv1d_nwc":
+        if op.tag == "conv1d_nwc" or op.tag == "group_conv1d_nwc":
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)
@@ -254,3 +267,18 @@ def _callback(op):
     traverse_inline(s, outs[0].op, _callback)
 
     return s
+
+
+@autotvm.register_topi_schedule("conv1d_nwc.cuda")
+def schedule_conv1d_nwc(cfg, outs):
+    return _schedule_conv1d_nwc(cfg, outs)
+
+
+@autotvm.register_topi_compute("group_conv1d_nwc.cuda")
+def group_conv1d_nwc(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
+    return nn.group_conv1d_nwc(data, kernel, strides, padding, dilation, groups, out_dtype)
+
+
+@autotvm.register_topi_schedule("group_conv1d_nwc.cuda")
+def schedule_group_conv1d_nwc(cfg, outs):
+    return _schedule_conv1d_nwc(cfg, outs)