Skip to content

Commit 2789011

Browse files
committed
fix ci
1 parent 9f4df30 commit 2789011

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

topi/python/topi/intel_graphics/conv2d.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
171171

172172
layout_name = 'layout' if F == sym else 'data_layout'
173173
layout = attrs[layout_name]
174+
kh, kw = attrs.get_int_tuple("kernel_size")
174175

175176
dtype = data.dtype
176177
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
@@ -191,6 +192,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
191192
if is_depthwise else \
192193
autotvm.task.args_to_workload(
193194
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
195+
if is_depthwise:
196+
return None
194197
cfg = dispatch_ctx.query(target, workload)
195198
if cfg.is_fallback:
196199
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
@@ -202,19 +205,17 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
202205

203206
new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
204207
dtype=data.dtype)
205-
if is_depthwise:
206-
raise RuntimeError("Intel graphics not supported depthwise schedule")
207-
else:
208-
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
209-
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
210-
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
211208

212-
# Store altered operator's config
213-
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
214-
dtype=kernel.dtype)
215-
new_workload = autotvm.task.args_to_workload(
216-
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
217-
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
209+
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
210+
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
211+
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
212+
213+
# Store altered operator's config
214+
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
215+
dtype=kernel.dtype)
216+
new_workload = autotvm.task.args_to_workload(
217+
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
218+
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
218219

219220
dispatch_ctx.update(target, new_workload, cfg)
220221
if F == sym:

0 commit comments

Comments
 (0)