
Commit a8468da

Author: Tristan Konolige (committed)
[TOPI] Support grouped conv1d
Generalize the conv2d compute statement to a generic convNd that supports any layout and groups. Replace some existing conv2d and conv1d compute statements with this generic compute. Also add a topi group_conv1d compute that uses the generic convNd compute. Existing schedules for conv1d work with group_conv1d, so they are reused.
1 parent 72d3efe commit a8468da
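
The commit message describes a generic convNd compute with group support. For orientation, the grouped 1-D case can be sketched directly in TVM's tensor-expression language; this is a minimal illustration assuming NCW data and OIW kernel layouts, no padding or dilation, and a hypothetical helper name (`group_conv1d_ncw_sketch`); it is not the convNd compute added by this commit.

# Hypothetical sketch (not the code in this commit): a grouped conv1d in NCW
# layout written with tvm.te, showing how `groups` partitions the channel axes.
import tvm
from tvm import te


def group_conv1d_ncw_sketch(data, kernel, stride=1, groups=1):
    # data: (batch, in_channels, width); kernel: (out_channels, in_channels // groups, kernel_w)
    batch, _, in_width = data.shape
    out_channels, channels_per_group, kernel_w = kernel.shape
    out_width = (in_width - kernel_w) // stride + 1  # no padding/dilation for brevity

    rc = te.reduce_axis((0, channels_per_group), name="rc")
    rw = te.reduce_axis((0, kernel_w), name="rw")
    return te.compute(
        (batch, out_channels, out_width),
        lambda n, c, w: te.sum(
            # output channel c belongs to group c // (out_channels // groups) and
            # only reads that group's slice of the input channels
            data[n, c // (out_channels // groups) * channels_per_group + rc, w * stride + rw]
            * kernel[c, rc, rw],
            axis=[rc, rw],
        ),
        tag="group_conv1d_ncw",
    )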

15 files changed (+478, -327 lines)

python/tvm/autotvm/task/topi_integration.py

Lines changed: 5 additions & 2 deletions
@@ -233,7 +233,10 @@ def wrapper(outs, *args, **kwargs):
             """wrapper function for topi schedule"""
             workload = get_workload(outs, task_name)
             if workload is None:
-                raise RuntimeError("Cannot find workload in attribute of this schedule")
+                raise RuntimeError(
+                    f"Cannot find TOPI workload {task_name}. "
+                    "Is it registered with `register_topi_compute`?"
+                )
             tgt = Target.current()
             cfg = DispatchContext.current.query(tgt, workload)
             return topi_schedule(cfg, outs, *args, **kwargs)
@@ -253,7 +256,7 @@ def traverse(tensors):
         for t in tensors:
             op = t.op
             wkl = traverse(op.input_tensors)
-            if wkl:
+            if wkl is not None:
                 return wkl

             if "workload" in op.attrs:

python/tvm/relay/frontend/onnx.py

Lines changed: 0 additions & 21 deletions
@@ -526,23 +526,6 @@ def _impl_v1(cls, inputs, attr, params):
                 raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
             attr.pop("auto_pad")

-        # Check if the requested convolution is a group conv1d, if so convert it to conv2d.
-        # TODO(jwfromm) Remove once proper group_conv1d is supported.
-        group_conv1d = False
-        if dimension_picker("conv")(attr) == "conv1d" and attr.get("group") != 1:
-            group_conv1d = True
-            # Expand input from NCW to NCHW
-            data = _op.expand_dims(data, axis=2)
-            # Expand kernel from OIW to OIHW
-            kernel = _op.expand_dims(kernel, axis=2)
-            # Add new value to kernel_shape, strices, dilation, pads, if needed
-            attr["kernel_shape"] = [1] + list(attr["kernel_shape"])
-            if "strides" in attr:
-                attr["strides"] = [1] + list(attr["strides"])
-            if "dilations" in attr:
-                attr["dilations"] = [1] + list(attr["dilations"])
-            if "pads" in attr:
-                attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]]
         attr["channels"] = kernel_shapes[0][0]
         out = AttrCvt(
             op_name=dimension_picker("conv"),
@@ -555,10 +538,6 @@ def _impl_v1(cls, inputs, attr, params):
             custom_check=dimension_constraint(),
         )([data, kernel], attr, params)

-        # If this was a group_conv1d, squish output back to NCW.
-        if group_conv1d:
-            out = _op.squeeze(out, axis=[2])
-
         use_bias = len(inputs) == 3
         if use_bias:
             out = _op.nn.bias_add(out, inputs[2])
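
With native group support, a grouped ONNX Conv over 1-D data no longer needs the expand-to-conv2d and squeeze detour removed above; the frontend can emit `nn.conv1d` with `groups` set directly. A hand-written Relay equivalent for illustration (the shapes are made up, not taken from this commit):

import tvm
from tvm import relay

# NCW data and OIW weight; 16 output channels split into 4 groups of 4
data = relay.var("data", shape=(1, 16, 32), dtype="float32")
weight = relay.var("weight", shape=(16, 4, 3), dtype="float32")
out = relay.nn.conv1d(data, weight, kernel_size=3, channels=16, groups=4, padding=1)
print(tvm.IRModule.from_expr(relay.Function([data, weight], out)))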

python/tvm/relay/op/strategy/cuda.py

Lines changed: 29 additions & 13 deletions
@@ -689,20 +689,36 @@ def conv1d_strategy_cuda(attrs, inputs, out_type, target):
     if dilation[0] < 1:
         raise ValueError("dilation should be a positive value")
     strategy = _op.OpStrategy()
-    if layout == "NCW":
-        strategy.add_implementation(
-            wrap_compute_conv1d(topi.cuda.conv1d_ncw),
-            wrap_topi_schedule(topi.cuda.schedule_conv1d_ncw),
-            name="conv1d_ncw.cuda",
-        )
-    elif layout == "NWC":
-        strategy.add_implementation(
-            wrap_compute_conv1d(topi.cuda.conv1d_nwc),
-            wrap_topi_schedule(topi.cuda.schedule_conv1d_nwc),
-            name="conv1d_nwc.cuda",
-        )
+    if attrs.groups == 1:
+        if layout == "NCW":
+            strategy.add_implementation(
+                wrap_compute_conv1d(topi.cuda.conv1d_ncw),
+                wrap_topi_schedule(topi.cuda.schedule_conv1d_ncw),
+                name="conv1d_ncw.cuda",
+            )
+        elif layout == "NWC":
+            strategy.add_implementation(
+                wrap_compute_conv1d(topi.cuda.conv1d_nwc),
+                wrap_topi_schedule(topi.cuda.schedule_conv1d_nwc),
+                name="conv1d_nwc.cuda",
+            )
+        else:
+            raise ValueError("Unsupported conv1d layout {}".format(layout))
     else:
-        raise ValueError("Unsupported conv1d layout {}".format(layout))
+        if layout == "NCW":
+            strategy.add_implementation(
+                wrap_compute_group_conv1d(topi.cuda.group_conv1d_ncw),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv1d_ncw),
+                name="group_conv1d_ncw.cuda",
+            )
+        elif layout == "NWC":
+            strategy.add_implementation(
+                wrap_compute_group_conv1d(topi.cuda.group_conv1d_nwc),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv1d_nwc),
+                name="group_conv1d_nwc.cuda",
+            )
+        else:
+            raise ValueError("Unsupported conv1d layout {}".format(layout))
     return strategy

python/tvm/relay/op/strategy/generic.py

Lines changed: 43 additions & 0 deletions
@@ -637,6 +637,49 @@ def conv1d_strategy(attrs, inputs, out_type, target):
     return strategy


+def wrap_compute_group_conv1d(topi_compute):
+    """wrap conv1d topi compute"""
+
+    def _compute_group_conv1d(attrs, inputs, out_type):
+        """Compute definition of conv1d"""
+        strides = get_const_tuple(attrs.strides)
+        padding = get_const_tuple(attrs.padding)
+        dilation = get_const_tuple(attrs.dilation)
+        out_dtype = attrs.out_dtype
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        return [
+            topi_compute(inputs[0], inputs[1], strides, padding, dilation, attrs.groups, out_dtype)
+        ]
+
+    return _compute_group_conv1d
+
+
+@override_native_generic_func("group_conv1d_strategy")
+def group_conv1d_strategy(attrs, inputs, out_type, target):
+    """group_conv1d generic strategy"""
+    logger.warning("group_conv1d is not optimized for this platform.")
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    if dilation[0] < 1:
+        raise ValueError("dilation should be a positive value")
+    strategy = _op.OpStrategy()
+    if layout == "NCW":
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.nn.group_conv1d_ncw),
+            wrap_topi_schedule(topi.generic.schedule_group_conv1d_ncw),
+            name="group_conv1d_ncw.generic",
+        )
+    elif layout == "NWC":
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.nn.group_conv1d_nwc),
+            wrap_topi_schedule(topi.generic.schedule_group_conv1d_nwc),
+            name="group_conv1d_nwc.generic",
+        )
+    else:
+        raise ValueError("Unsupported conv1d layout {}".format(layout))
+    return strategy
+
+
 # conv1d_transpose
 def wrap_compute_conv1d_transpose(topi_compute):
     """wrap conv1d_transpose topi compute"""

python/tvm/relay/op/strategy/x86.py

Lines changed: 30 additions & 13 deletions
@@ -360,24 +360,41 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target):
 def conv1d_strategy_cpu(attrs, inputs, out_type, target):
     """conv1d x86 strategy"""
     layout = attrs.data_layout
+    groups = attrs.groups
     dilation = get_const_tuple(attrs.dilation)
     if dilation[0] < 1:
         raise ValueError("dilation should be a positive value")
     strategy = _op.OpStrategy()
-    if layout == "NCW":
-        strategy.add_implementation(
-            wrap_compute_conv1d(topi.nn.conv1d_ncw),
-            wrap_topi_schedule(topi.x86.schedule_conv1d_ncw),
-            name="conv1d_ncw.x86",
-        )
-    elif layout == "NWC":
-        strategy.add_implementation(
-            wrap_compute_conv1d(topi.nn.conv1d_nwc),
-            wrap_topi_schedule(topi.x86.schedule_conv1d_nwc),
-            name="conv1d_nwc.x86",
-        )
+    if groups == 1:
+        if layout == "NCW":
+            strategy.add_implementation(
+                wrap_compute_conv1d(topi.nn.conv1d_ncw),
+                wrap_topi_schedule(topi.x86.schedule_conv1d_ncw),
+                name="conv1d_ncw.x86",
+            )
+        elif layout == "NWC":
+            strategy.add_implementation(
+                wrap_compute_conv1d(topi.nn.conv1d_nwc),
+                wrap_topi_schedule(topi.x86.schedule_conv1d_nwc),
+                name="conv1d_nwc.x86",
+            )
+        else:
+            raise ValueError("Unsupported conv1d layout {}".format(layout))
     else:
-        raise ValueError("Unsupported conv1d layout {}".format(layout))
+        if layout == "NCW":
+            strategy.add_implementation(
+                wrap_compute_group_conv1d(topi.nn.group_conv1d_ncw),
+                wrap_topi_schedule(topi.x86.schedule_group_conv1d_ncw),
+                name="group_conv1d_ncw.x86",
+            )
+        elif layout == "NWC":
+            strategy.add_implementation(
+                wrap_compute_group_conv1d(topi.nn.group_conv1d_nwc),
+                wrap_topi_schedule(topi.x86.schedule_group_conv1d_nwc),
+                name="group_conv1d_nwc.x86",
+            )
+        else:
+            raise ValueError("Unsupported conv1d layout {}".format(layout))
     return strategy
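
With the strategies above registered, a grouped conv1d can be compiled and run end to end on CPU. A hedged usage sketch (the shapes and random inputs are illustrative only):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

data = relay.var("data", shape=(1, 8, 32), dtype="float32")    # NCW
weight = relay.var("weight", shape=(8, 2, 3), dtype="float32")  # OIW, groups=4
conv = relay.nn.conv1d(data, weight, kernel_size=3, channels=8, groups=4, padding=1)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")

dev = tvm.cpu()
runtime = graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("data", np.random.rand(1, 8, 32).astype("float32"))
runtime.set_input("weight", np.random.rand(8, 2, 3).astype("float32"))
runtime.run()
print(runtime.get_output(0).shape)  # (1, 8, 32): width preserved with padding=1, stride=1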

python/tvm/topi/cuda/conv1d.py

Lines changed: 34 additions & 6 deletions
@@ -29,8 +29,7 @@ def conv1d_ncw(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     return nn.conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)


-@autotvm.register_topi_schedule("conv1d_ncw.cuda")
-def schedule_conv1d_ncw(cfg, outs):
+def _schedule_conv1d_ncw(cfg, outs):
     """TOPI schedule callback of conv1d ncw for cuda gpu

     Parameters
@@ -51,7 +50,7 @@ def schedule_conv1d_ncw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):
-        if op.tag == "conv1d_ncw":
+        if op.tag == "conv1d_ncw" or op.tag == "group_conv1d_ncw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -140,13 +139,27 @@ def _callback(op):
     return s


+@autotvm.register_topi_schedule("conv1d_ncw.cuda")
+def schedule_conv1d_ncw(cfg, outs):
+    return _schedule_conv1d_ncw(cfg, outs)
+
+
+@autotvm.register_topi_compute("group_conv1d_ncw.cuda")
+def group_conv1d_ncw(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
+    return nn.group_conv1d_ncw(data, kernel, strides, padding, dilation, groups, out_dtype)
+
+
+@autotvm.register_topi_schedule("group_conv1d_ncw.cuda")
+def schedule_group_conv1d_ncw(cfg, outs):
+    return _schedule_conv1d_ncw(cfg, outs)
+
+
 @autotvm.register_topi_compute("conv1d_nwc.cuda")
 def conv1d_nwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     return nn.conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)


-@autotvm.register_topi_schedule("conv1d_nwc.cuda")
-def schedule_conv1d_nwc(cfg, outs):
+def _schedule_conv1d_nwc(cfg, outs):
     """TOPI schedule callback of conv1d nwc for cuda gpu

     Parameters
@@ -167,7 +180,7 @@ def schedule_conv1d_nwc(cfg, outs):
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):
-        if op.tag == "conv1d_nwc":
+        if op.tag == "conv1d_nwc" or op.tag == "group_conv1d_nwc":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -254,3 +267,18 @@ def _callback(op):
     traverse_inline(s, outs[0].op, _callback)

     return s
+
+
+@autotvm.register_topi_schedule("conv1d_nwc.cuda")
+def schedule_conv1d_nwc(cfg, outs):
+    return _schedule_conv1d_nwc(cfg, outs)
+
+
+@autotvm.register_topi_compute("group_conv1d_nwc.cuda")
+def group_conv1d_nwc(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
+    return nn.group_conv1d_nwc(data, kernel, strides, padding, dilation, groups, out_dtype)
+
+
+@autotvm.register_topi_schedule("group_conv1d_nwc.cuda")
+def schedule_group_conv1d_nwc(cfg, outs):
+    return _schedule_conv1d_nwc(cfg, outs)

python/tvm/topi/generic/nn.py

Lines changed: 34 additions & 0 deletions
@@ -54,6 +54,40 @@ def schedule_conv1d_nwc(outs):
     return _default_schedule(outs, False)


+def schedule_group_conv1d_ncw(outs):
+    """Schedule for group_conv1d_ncw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of group_conv1d_ncw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+def schedule_group_conv1d_nwc(outs):
+    """Schedule for group_conv1d_nwc
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of group_conv1d_nwc
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_conv2d_hwcn(outs):
     """Schedule for conv2d_hwcn
