
Commit 1ebde43

GaryYuyjl authored and Laurawly committed
[TOPI] Support int4/int8 conv2d tensor core with HWNC layout (apache#6121)
* int4 tensorcore
* a draft for new int4 schedule
* update layout
* add inline option
* clean code
* increase search space
* fix kernel shape
* update intrinsic
* update intrinsic
* support int4/int8 hwnc layout
* remove useless code
* remove useless code
* remove useless code
* remove useless code
* fix int8 transpose
* fix assert
* add asf header
* CI
* CI
* CI
* fix bug

Co-authored-by: Leyuan Wang <[email protected]>
1 parent b04ae59 commit 1ebde43
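For context, here is a minimal Relay sketch (not part of this commit; shapes and variable names are illustrative) of the workload the new path targets: an int4 conv2d with HWNC activations and an HWOI kernel, sized to satisfy the tensor core constraints enforced in the diffs below (in_channels % 32 == 0 and out_channels % 8 == 0 for int4):

```python
# Hedged sketch: shapes and names here are illustrative, not from the commit.
import tvm
from tvm import relay

# HWNC activations, HWOI weights; int4 is a native TVM dtype.
data = relay.var("data", shape=(56, 56, 16, 64), dtype="int4")     # H, W, N, C
kernel = relay.var("kernel", shape=(3, 3, 256, 64), dtype="int4")  # H, W, O, I
conv = relay.nn.conv2d(
    data, kernel,
    channels=256, kernel_size=(3, 3), padding=(1, 1),
    data_layout="HWNC", kernel_layout="HWOI", out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, kernel], conv))
```

A graph like this should be eligible for the new conv2d_hwnc_tensorcore_direct.cuda implementation registered below, on GPUs with compute capability 7.5 or higher.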

File tree: 6 files changed, +629 −2 lines changed

python/tvm/relay/op/strategy/cuda.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -172,6 +172,25 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
                         name="conv2d_nhwc_tensorcore.cuda",
                         plevel=20)
+        elif layout == "HWNC":
+            assert kernel_layout in ["HWOI", "HWOI16o16i", "HWOI8o32i", "HWOI32o16i"]
+            _, _, N, in_channels = get_const_tuple(data.shape)
+            pre_computed = len(kernel.shape) == 6
+            if pre_computed:
+                _, _, oc_chunk, _, oc_block_factor, _ = get_const_tuple(kernel.shape)
+                out_channels = oc_chunk * oc_block_factor
+            else:
+                _, _, out_channels, _ = get_const_tuple(kernel.shape)
+            if topi.cuda.is_shape_tensorcore_direct_qualified(
+                    batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype):
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore),
+                    name="conv2d_hwnc_tensorcore_direct.cuda",
+                    plevel=20)
+            else:
+                raise RuntimeError("Unsupported shape for conv2d HWNC.\
+                                    Need to satisfy tensor core schedule.")
         elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
```
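The strategy first normalizes out_channels, reading it from an unpacked HWOI kernel or recomputing it as oc_chunk * oc_block_factor when the kernel is already packed to 6-D, and then gates on topi.cuda.is_shape_tensorcore_direct_qualified. A hedged sketch of roughly what that gate amounts to, inferred from the divisibility checks in conv2d_alter_op.py below (the batch % 8 constraint is an assumption based on WMMA fragment shapes and is not shown in this diff):

```python
# Hedged sketch of the shape gate; the authoritative check is
# topi.cuda.is_shape_tensorcore_direct_qualified.
def hwnc_tensorcore_qualified(batch, in_channels, num_filter, in_dtype):
    if in_dtype in ("int4", "uint4"):
        # Assumption: batch % 8 from the 8x8x32 int4 WMMA fragment.
        return batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0
    if in_dtype in ("int8", "uint8"):
        # Mirrors the CI % 16 / CO % 32 gate in conv2d_alter_op.py below.
        return batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0
    return False

assert hwnc_tensorcore_qualified(16, 64, 256, "int4")
```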

python/tvm/topi/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -50,5 +50,6 @@
 from .conv2d_nhwc_tensorcore import *
 from .conv3d_ndhwc_tensorcore import *
 from .dense_tensorcore import *
+from .conv2d_hwnc_tensorcore import *
 from .correlation import *
 from .sparse import *
```

python/tvm/topi/cuda/conv2d_alter_op.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -171,6 +171,36 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.conv2d(*inputs, **new_attrs)
 
+    if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda":
+        assert data_layout == "HWNC" and kernel_layout == "HWOI"
+        assert float(tvm.gpu(0).compute_version) >= 7.5
+        H, W, N, CI = get_const_tuple(data.shape)
+        KH, KW, CO, _ = get_const_tuple(kernel.shape)
+
+        if kernel.dtype in ['int4', 'uint4'] and (CI % 32 != 0 or CO % 8 != 0) or \
+                kernel.dtype in ['int8', 'uint8'] and (CI % 16 != 0 or CO % 32 != 0):
+            return relay.nn.conv2d(*inputs, **new_attrs)
+
+        new_attrs["channels"] = CO
+        if kernel.dtype in ['int4', 'uint4']:
+            new_attrs['kernel_layout'] = 'HWOI8o32i'
+            ic_block_factor = 32
+            oc_block_factor = 8
+        else:
+            new_attrs['kernel_layout'] = 'HWOI32o16i'
+            ic_block_factor = 16
+            oc_block_factor = 32
+
+        new_kernel = te.placeholder((KH, KW, CO // oc_block_factor, CI // ic_block_factor,
+                                     oc_block_factor, ic_block_factor), dtype=kernel.dtype)
+
+        new_workload = autotvm.task.args_to_workload(
+            [data, new_kernel, strides, padding, dilation, out_dtype],
+            "conv2d_HWNCnc_tensorcore.cuda")
+
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
     return None
 
 
 @conv2d_legalize.register("cuda")
```
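The 6-D te.placeholder spells out the packing that the new kernel layouts request: output channels tiled by oc_block_factor and input channels by ic_block_factor (8o32i for int4, 32o16i for int8). A hedged numpy sketch of the int4 case follows, with 4-bit values stored in an int8 array since numpy has no 4-bit type; the real relayout is generated by TVM's alter-layout pass, not hand-written:

```python
import numpy as np

# Hedged sketch of the HWOI -> HWOI8o32i relayout implied by new_kernel's
# shape: (KH, KW, CO, CI) -> (KH, KW, CO // 8, CI // 32, 8, 32).
def pack_hwoi_to_hwoi8o32i(kernel):
    KH, KW, CO, CI = kernel.shape
    assert CO % 8 == 0 and CI % 32 == 0
    tiled = kernel.reshape(KH, KW, CO // 8, 8, CI // 32, 32)
    return tiled.transpose(0, 1, 2, 4, 3, 5)  # (KH, KW, CO//8, CI//32, 8, 32)

w = np.random.randint(-8, 8, size=(3, 3, 256, 64)).astype(np.int8)
print(pack_hwoi_to_hwoi8o32i(w).shape)        # (3, 3, 32, 2, 8, 32)
```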
