@@ -171,6 +171,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
171171
172172 layout_name = 'layout' if F == sym else 'data_layout'
173173 layout = attrs [layout_name ]
174+ kh , kw = attrs .get_int_tuple ("kernel_size" )
174175
175176 dtype = data .dtype
176177 out_dtype = dtype if out_dtype in ("same" , "" ) else out_dtype
@@ -191,6 +192,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
191192 if is_depthwise else \
192193 autotvm .task .args_to_workload (
193194 [data , kernel , strides , padding , dilation , layout , out_dtype ], conv2d )
195+ if is_depthwise :
196+ return None
194197 cfg = dispatch_ctx .query (target , workload )
195198 if cfg .is_fallback :
196199 _get_default_config (cfg , data , kernel , strides , padding , out_dtype , is_depthwise )
@@ -202,19 +205,17 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
202205
203206 new_data = tvm .placeholder ((batch_size , in_channel // ic_bn , height , width , ic_bn ),
204207 dtype = data .dtype )
205- if is_depthwise :
206- raise RuntimeError ("Intel graphics not supported depthwise schedule" )
207- else :
208- out_channel , _ , kh , kw = get_const_tuple (kernel .shape )
209- # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
210- new_attrs ['kernel_layout' ] = 'OIHW%di%do' % (ic_bn , oc_bn )
211208
212- # Store altered operator's config
213- new_kernel = tvm .placeholder ((out_channel // oc_bn , in_channel // ic_bn , kh , kw , ic_bn , oc_bn ),
214- dtype = kernel .dtype )
215- new_workload = autotvm .task .args_to_workload (
216- [new_data , new_kernel , strides , padding , dilation , new_attrs [layout_name ],
217- new_attrs ['out_layout' ], out_dtype ], conv2d_NCHWc )
209+ out_channel , _ , kh , kw = get_const_tuple (kernel .shape )
210+ # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
211+ new_attrs ['kernel_layout' ] = 'OIHW%di%do' % (ic_bn , oc_bn )
212+
213+ # Store altered operator's config
214+ new_kernel = tvm .placeholder ((out_channel // oc_bn , in_channel // ic_bn , kh , kw , ic_bn , oc_bn ),
215+ dtype = kernel .dtype )
216+ new_workload = autotvm .task .args_to_workload (
217+ [new_data , new_kernel , strides , padding , dilation , new_attrs [layout_name ],
218+ new_attrs ['out_layout' ], out_dtype ], conv2d_NCHWc )
218219
219220 dispatch_ctx .update (target , new_workload , cfg )
220221 if F == sym :
0 commit comments