Commit a34731b

Authored by Masahiro Masuda
[ROCM] DP4A intrinsic support for TE/TIR (#11009)
* [ROCM] Support dp4a on AMDGPU by sdot4 intrinsic

  Squashed commit history (every commit below is authored by Masahiro Masuda <[email protected]>):

  - 0225f2b  Thu Apr 14 08:56:10 2022 +0900  share op strategy between cuda and rocm
  - 762c7e8  Thu Apr 14 08:28:34 2022 +0900  fixed rocm batch_matmul strategy for mixed i8i8i32
  - ce53e8d  Thu Apr 14 06:17:30 2022 +0900  add rocm sdot4 TIR intrin
  - f4562b9  Thu Apr 14 06:03:44 2022 +0900  rocm sdot4 works
  - 6cc6280  Thu Apr 14 05:32:07 2022 +0900  more wip
  - 0602f4a  Thu Apr 14 03:47:37 2022 +0900  Squashed commit of the following:
  - 65b8bcf  Wed Apr 13 20:36:49 2022 +0900  [WIP] adding DP4A support to rocm
  - 4f8f308  Wed Apr 13 14:03:25 2022 +0900  Squashed commit of the following:
  - 1711be3  Wed Apr 13 13:11:40 2022 +0900  fixed condition for real
  - 8a48fb5  Wed Apr 13 09:57:42 2022 +0900  Revert "Skip applying sch_rule when both ann and sch_rule are defined" (this reverts commit 4915c6a)
  - daea033  Mon Apr 11 09:31:05 2022 +0900  [Metaschedule] Support rocm and spirv
  - eb0cae2  Wed Apr 13 07:25:04 2022 +0900  dp4a works
  - 4915c6a  Wed Apr 13 06:13:45 2022 +0900  Skip applying sch_rule when both ann and sch_rule are defined
  - 7b3d71c  Wed Apr 13 04:40:31 2022 +0900  fixed intrin description
  - 7666cd7  Tue Apr 12 19:59:47 2022 +0900  add DP4A intrin
  - 7086bdb  Tue Apr 12 19:03:44 2022 +0900  works
  - db34397  Tue Apr 12 12:49:52 2022 +0900  more hack to tensorize loop mapping to make resnet50 e2e work
  - 2409674  Mon Apr 11 13:40:59 2022 +0900  wip support pad + qnn.conv2d folding
  - 613cb7e  Sun Apr 10 12:04:08 2022 +0900  hack to tensorize loop mapping to make conv2d work
  - 9e4f9df  Sun Apr 10 11:34:13 2022 +0900  wrap tensorize with try/catch
  - d4b496d  Sun Apr 10 11:33:39 2022 +0900  revert change in task_scheduler.cc
  - 476129b  Sat Apr 9 05:54:10 2022 +0900  try / catch in ThreadedApply
  - d8226ff  Fri Apr 8 17:17:59 2022 +0900  filter out invalid candidate
  - 2632899  Fri Apr 8 10:09:48 2022 +0900  try graceful exit in parallel_for_dynamic
  - 9d6741c  Fri Apr 8 09:35:51 2022 +0900  [QNN] Fix broadcast for invalid axis
  - 6ccde09  Thu Apr 7 20:51:15 2022 +0900  refactor rewrite_tensorize
  - 2ce2066  Thu Apr 7 20:48:17 2022 +0900  allow missing schedule_rule in post order apply
  - 3a69353  Thu Apr 7 19:42:48 2022 +0900  refactor rewrite_tensorize
  - 43e0b2f  Thu Apr 7 18:25:14 2022 +0900  rewrite_vnni -> rewrite_tensorize
  - 823797e  Thu Apr 7 18:12:12 2022 +0900  VNNI -> WithIntrin
  - 4284a47  Thu Apr 7 17:45:41 2022 +0900  introduce TileForIntrin
  - b87ef32  Thu Apr 7 17:34:04 2022 +0900  move TilingwithTensorIntrin to auto_tensorize.cc
  - 2fc118b  Thu Apr 7 17:28:45 2022 +0900  clean up headers
  - d8b2aa3  Thu Apr 7 17:09:32 2022 +0900  clean up using namespace
  - eb05d25  Thu Apr 7 17:03:05 2022 +0900  refactored init
  - 5e6b0a0  Thu Apr 7 16:57:14 2022 +0900  compiled
  - 2b8c430  Thu Apr 7 12:51:55 2022 +0900  wip MultiLevelTiling refactor
  - 7c21a9f  Thu Apr 7 11:58:33 2022 +0900  function doc string not supported by tvmscript
  - 40f9742  Thu Apr 7 11:56:45 2022 +0900  update vnni intrin name
  - 4814f82  (Merge: e0c5eb8 07bbb38)  Thu Apr 7 11:44:47 2022 +0900  Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni
  - 07bbb38  Thu Apr 7 11:24:56 2022 +0900  more lint fix
  - 15e60b4  Thu Apr 7 11:16:08 2022 +0900  black
  - 7a757fe  Thu Apr 7 11:12:54 2022 +0900  pylint
  - 9a3e508  Thu Apr 7 10:58:52 2022 +0900  simplify import
  - d8e43ec  Thu Apr 7 10:52:50 2022 +0900  use vectorlow/high in arm intrin
  - 625cd27  Thu Apr 7 10:34:57 2022 +0900  fixed offset factor
  - 69e72b6  Thu Apr 7 10:12:02 2022 +0900  Add ARM intrin
  - 1351fde  Thu Apr 7 08:27:27 2022 +0900  use buffer syntax sugar
  - 0ced85f  Thu Apr 7 08:17:43 2022 +0900  rename vnni.py to x86.py
  - 38a5aca  Thu Apr 7 07:24:44 2022 +0900  add VNNI unittest
  - 88b763e  Thu Apr 7 07:10:06 2022 +0900  refactored existing test using VNNI intrin
  - 711a007  Thu Apr 7 07:04:58 2022 +0900  [TIR] Add VNNI dot product intrinsic for TIR
  - e0c5eb8  Thu Apr 7 11:42:26 2022 +0900  merge fix
  - b171748  (Merge: 71fe3bd 82e152a)  Thu Apr 7 11:33:59 2022 +0900  Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni
  - 71fe3bd  Thu Apr 7 06:57:38 2022 +0900  move tensor intrin under tir
  - 0c51bad  Thu Apr 7 06:12:39 2022 +0900  remove log
  - fed910e  Thu Apr 7 06:11:22 2022 +0900  more revert
  - 7150aff  Thu Apr 7 06:10:44 2022 +0900  revert stmt_functor change
  - 155107b  Thu Apr 7 06:10:09 2022 +0900  refactored RewriteVNNI a bit
  - ca15255  Thu Apr 7 05:41:13 2022 +0900  add RewriteVNNI
  - dc9f71d  Thu Apr 7 05:38:56 2022 +0900  vectorized init loop
  - fcc31ee  Thu Apr 7 04:55:36 2022 +0900  tensorize worked
  - 2b53437  Wed Apr 6 19:11:05 2022 +0900  TilingwithTensorIntrin works
  - 86baa31  Wed Apr 6 08:58:27 2022 +0900  Ported auto-tensorization code
  - 82e152a  Thu Apr 7 11:24:56 2022 +0900  more lint fix
  - 88d9bdd  Thu Apr 7 11:16:08 2022 +0900  black
  - 31fe7eb  Thu Apr 7 11:12:54 2022 +0900  pylint
  - 7876754  Thu Apr 7 10:58:52 2022 +0900  simplify import
  - 56f2e9a  Thu Apr 7 10:52:50 2022 +0900  use vectorlow/high in arm intrin
  - 995cc8d  Thu Apr 7 10:34:57 2022 +0900  fixed offset factor
  - 86bbd49  Thu Apr 7 10:12:02 2022 +0900  Add ARM intrin
  - 120fd96  Thu Apr 7 08:27:27 2022 +0900  use buffer syntax sugar
  - 0f0682d  Thu Apr 7 08:17:43 2022 +0900  rename vnni.py to x86.py
  - f88c31e  Thu Apr 7 07:24:44 2022 +0900  add VNNI unittest
  - 6cc8009  Thu Apr 7 07:10:06 2022 +0900  refactored existing test using VNNI intrin
  - 11a29c7  Thu Apr 7 07:04:58 2022 +0900  [TIR] Add VNNI dot product intrinsic for TIR

* cleanup
* black
* update dot prod intrin
* add mattr kind
* conv2d topi test working
* add dense and bmm test
* add conv2d relay test
* add tir intrin test
* pylint
1 parent 3d63b2d commit a34731b
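
For readers unfamiliar with the instructions this PR targets: dp4a (NVIDIA) and sdot4 (AMDGPU) both compute a 4-lane int8 dot product and accumulate the result into an int32. A minimal plain-Python illustration of that arithmetic (not TVM code, values are made up):

```python
# One dp4a / sdot4 step: a 4-lane int8 dot product accumulated into int32.
a = [1, -2, 3, 4]   # four int8 lanes of operand A
b = [5, 6, -7, 8]   # four int8 lanes of operand B
c = 10              # int32 accumulator
c += sum(x * y for x, y in zip(a, b))
print(c)  # 10 + (5 - 12 - 21 + 32) = 14
```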

20 files changed: +358 additions, −273 deletions


python/tvm/relay/op/strategy/cuda.py

Lines changed: 4 additions & 4 deletions
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
             if (
-                (target.kind.name in ["cuda", "vulkan"])
+                (target.kind.name in ["cuda", "vulkan", "rocm"])
                 and data.dtype in ("int8", "uint8")
                 and kernel.dtype in ("int8", "uint8")
             ):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 Need to satisfy tensor core schedule."
             )
         elif (
-            (target.kind.name in ["cuda", "vulkan"])
+            (target.kind.name in ["cuda", "vulkan", "rocm"])
             and layout == "NCHW4c"
             and data.dtype in ["int8", "uint8"]
         ):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ic_chunk = in_channels // 4
 
         if (
-            (target.kind.name in ["cuda", "vulkan"])
+            (target.kind.name in ["cuda", "vulkan", "rocm"])
            and data.dtype in ["int8", "uint8"]
            and kernel.dtype in ["int8", "uint8"]
            and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
     b, i = get_const_tuple(data.shape)
     o, _ = get_const_tuple(weights.shape)
     if (
-        target.kind.name in ["cuda", "vulkan"]
+        target.kind.name in ["cuda", "vulkan", "rocm"]
         and data.dtype == "int8"
         and weights.dtype == "int8"
         and out_type.dtype == "int32"
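
With "rocm" added to these target checks, the int8 CUDA schedules that the DP4A intrinsic accelerates become selectable when compiling for AMD GPUs. A hedged sketch of a Relay workload that would now take the int8 dense path on a ROCm target (shapes are illustrative; a ROCm-enabled TVM build is assumed):

```python
import tvm
from tvm import relay

# Hypothetical int8 x int8 -> int32 dense workload; with "rocm" in the target
# checks above, relay.build on a ROCm target can pick the same dp4a-capable
# schedule that CUDA uses.
data = relay.var("data", shape=(16, 64), dtype="int8")
weight = relay.var("weight", shape=(32, 64), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="rocm")
```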

python/tvm/relay/op/strategy/rocm.py

Lines changed: 20 additions & 152 deletions
@@ -17,175 +17,48 @@
 """Definition of ROCm operator strategy."""
 # pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
 from tvm import topi
-from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.te import SpecializedCondition
 from tvm.contrib.thrust import can_use_rocthrust
 from tvm.contrib import miopen
 
 from .generic import *
 from .. import op as _op
-from .cuda import judge_winograd, naive_schedule
+from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda
 
 
 @conv2d_strategy.register("rocm")
 def conv2d_strategy_rocm(attrs, inputs, out_type, target):
     """conv2d rocm strategy"""
-    strategy = _op.OpStrategy()
-    data, kernel = inputs
-    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     groups = attrs.groups
     layout = attrs.data_layout
-    stride_h, stride_w = attrs.get_int_tuple("strides")
-    kernel_layout = attrs.kernel_layout
     padding = attrs.get_int_tuple("padding")
-    if dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
-
-    if groups == 1:
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                name="conv2d_nchw.cuda",
-            )
-            _, _, kh, kw = get_const_tuple(kernel.shape)
-            if (
-                2 < kh < 8
-                and 2 < kw < 8
-                and kh == kw
-                and stride_h == 1
-                and stride_w == 1
-                and dilation_h == 1
-                and dilation_w == 1
-            ):
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
-                    name="conv2d_nchw_winograd.cuda",
-                    plevel=5,
-                )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
-                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.gpu",
-            )
-            N, H, W, _ = get_const_tuple(data.shape)
-            KH, KW, CI, CO = get_const_tuple(kernel.shape)
 
-            (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
-                N,
-                H,
-                W,
-                KH,
-                KW,
-                CI,
-                CO,
-                padding,
-                stride_h,
-                stride_w,
-                dilation_h,
-                dilation_w,
-                data.dtype,
-                kernel.dtype,
-                pre_flag=False,
-            )
+    strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)
 
-            if judge_winograd_autotvm:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
-                    name="conv2d_nhwc_winograd_direct.cuda",
-                    plevel=5,
-                )
+    # add miopen implementation
+    if (
+        "miopen" in target.libs
+        and groups == 1
+        and layout == "NCHW"
+        and padding[0] == padding[2]
+        and padding[1] == padding[3]
+    ):
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
+            name="conv2d_nchw_miopen.rocm",
+            plevel=50,
+        )
 
-            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
-                    naive_schedule,  # this implementation should never be picked by autotvm
-                    name="conv2d_nhwc.winograd",
-                    plevel=15,
-                )
-        elif layout == "HWCN":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
-                name="conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
-        # add miopen implementation
-        if (
-            "miopen" in target.libs
-            and layout == "NCHW"
-            and padding[0] == padding[2]
-            and padding[1] == padding[3]
-        ):
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
-                wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
-                name="conv2d_nchw_miopen.rocm",
-                plevel=15,
-            )
-    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
-        if layout == "NCHW":
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.cuda",
-            )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWOI"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else:  # group_conv2d
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
-                name="group_conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
 
 @dense_strategy.register("rocm")
 def dense_strategy_rocm(attrs, inputs, out_type, target):
     """Dense strategy for ROCM"""
     assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.rocm.dense),
-        wrap_topi_schedule(topi.rocm.schedule_dense),
-        name="dense.rocm",
-    )
+    strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
+
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
@@ -200,13 +73,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
 @batch_matmul_strategy.register("rocm")
 def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
     """Batch matmul strategy for ROCM"""
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
-        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
-        plevel=10,
-    )
+    strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)
+
    if target.kind.name == "rocm" and "rocblas" in target.libs:
        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
        strategy.add_implementation(
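
The ROCm conv2d, dense, and batch_matmul strategies now start from their CUDA counterparts and only append the ROCm-specific MIOpen/rocBLAS implementations on top. A hedged usage sketch (hypothetical shapes; assumes a TVM build with ROCm and rocBLAS available):

```python
import tvm
from tvm import relay

# Hypothetical batch_matmul workload. batch_matmul_strategy_rocm now reuses the
# CUDA strategy and adds the rocBLAS implementation only when "rocblas" appears
# in target.libs.
x = relay.var("x", shape=(8, 32, 64), dtype="float32")
y = relay.var("y", shape=(8, 128, 64), dtype="float32")
out = relay.nn.batch_matmul(x, y)
mod = tvm.IRModule.from_expr(relay.Function([x, y], out))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="rocm -libs=rocblas")
```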

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 5 additions & 17 deletions
@@ -24,6 +24,7 @@
 from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op
 
 from ....topi.x86.utils import target_has_sse42
+from ....topi.utils import is_target
 from .. import op as reg
 
 #################################################
@@ -387,18 +388,6 @@ def is_aarch64_arm():
     return "aarch64" in target.attrs.get("mtriple", "")
 
 
-def is_vulkan():
-    """Checks whether we are compiling for a vulkan/spirv target."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "vulkan" in target.keys
-
-
-def is_cuda():
-    """Checks whether we are compiling for a cuda target."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "cuda" in target.keys
-
-
 ########################
 # ARM CPU legalizations.
 ########################
@@ -456,22 +445,21 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
 
 @qnn_conv2d_legalize.register(["cuda", "gpu"])
 def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
-    if is_vulkan():
+    if is_target("vulkan"):
         # prefers the dtypes to be same. Mixed type is not yet supported.
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
-    if is_cuda():
+    if is_target(["cuda", "rocm"]):
         # CUDA prefers both datatypes to be int8.
         return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d)
     return None
 
 
 @qnn_dense_legalize.register(["cuda", "gpu"])
 def _qnn_dense_legalize_cuda(attrs, inputs, types):
-    if is_vulkan():
+    if is_target("vulkan"):
         # prefers the dtypes to be same. Mixed type is not yet supported.
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
-    if is_cuda():
+    if is_target(["cuda", "rocm"]):
         # CUDA prefers both datatypes to be the int8.
         return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)
-
     return None
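
Because ROCm targets also carry the "gpu" key, the `["cuda", "gpu"]` legalization above now rewrites mixed uint8/int8 QNN operators to all-int8 when compiling for ROCm as well. A hedged sketch of exercising the QNN Legalize pass under a rocm target (operator attributes and shapes are illustrative, not taken from the PR):

```python
import tvm
from tvm import relay

# Hypothetical mixed-dtype qnn.conv2d (uint8 data, int8 kernel). Under a rocm
# target, _qnn_conv2d_legalize_cuda takes the is_target(["cuda", "rocm"]) branch
# and converts both operands to int8.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="uint8")
kernel = relay.var("kernel", shape=(16, 3, 3, 3), dtype="int8")
conv = relay.qnn.op.conv2d(
    data,
    kernel,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(1.0, "float32"),
    kernel_scale=relay.const(1.0, "float32"),
    kernel_size=(3, 3),
    channels=16,
    out_dtype="int32",
)
mod = tvm.IRModule.from_expr(relay.Function([data, kernel], conv))

with tvm.target.Target("rocm"):          # legalization dispatches on the current target
    legalized = relay.qnn.transform.Legalize()(mod)
```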

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,3 +18,5 @@
 """Intrinsics for tensorization."""
 from .x86 import *
 from .arm_cpu import *
+from .dot_product_common import *
+from .rocm import *
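
Importing these modules is what actually registers the intrinsics by name, so user code only needs the import for its side effect. A small sketch; looking the intrinsic up through `tvm.tir.TensorIntrin.get` is an assumption about this TVM version's registry API:

```python
import tvm
# The import below runs TensorIntrin.register(...) as a side effect.
from tvm.tir.tensor_intrin.dot_product_common import DP4A_INTRIN

intrin = tvm.tir.TensorIntrin.get(DP4A_INTRIN)  # assumed lookup API
print(intrin)
```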

python/tvm/tir/tensor_intrin/dot_product_common.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Dot product related intrinsics."""
+from tvm.script import tir as T
+from .. import TensorIntrin
+
+
+@T.prim_func
+def dp4a_desc(
+    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0], A[0:4], B[0:4])
+        T.writes(C[0])
+        for i in range(0, 4):
+            with T.block("update"):
+                vi = T.axis.remap("R", [i])
+                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
+
+
+@T.prim_func
+def dp4a_impl(
+    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0], A[0:4], B[0:4])
+        T.writes(C[0])
+
+        C[0] += T.call_pure_extern(
+            "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
+        )
+
+
+DP4A_INTRIN = "dp4a"
+
+TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
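
A hedged sketch of how the registered intrinsic is meant to be consumed: build an int8 GEMM with TE, split the reduction loop into chunks of four, and, after staging `A`/`B` into shared memory and the accumulator into a local cache so the block matches `dp4a_desc`, map the inner loop onto the intrinsic with `Schedule.tensorize`. Shapes and tiling below are illustrative, not taken from the PR; the tensorize call itself is left commented because it needs the full staging recipe found in TVM's tensorize unit tests.

```python
import tvm
from tvm import te, tir
from tvm.tir.tensor_intrin.dot_product_common import DP4A_INTRIN  # registers "dp4a"

# Illustrative int8 GEMM whose reduction axis is a multiple of 4.
m, n, k = 128, 128, 128
X = te.placeholder((m, k), name="X", dtype="int8")
W = te.placeholder((n, k), name="W", dtype="int8")
ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
    (m, n),
    lambda i, j: te.sum(X[i, ak].astype("int32") * W[j, ak].astype("int32"), axis=ak),
    name="compute",
)

sch = tir.Schedule(te.create_prim_func([X, W, matmul]))
block = sch.get_block("compute")
i, j, kk = sch.get_loops(block)
ko, ki = sch.split(kk, factors=[None, 4])  # expose a 4-wide inner reduction

# To satisfy dp4a_desc, A/B must live in "shared" scope and C in "local" scope,
# e.g. via sch.cache_read(block, 0, "shared"), sch.cache_read(block, 1, "shared"),
# and sch.cache_write(block, 0, "local"), plus sch.decompose_reduction(block, ko)
# so the update block has no init. With that staging in place:
# sch.tensorize(ki, DP4A_INTRIN)
print(sch.mod)
```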
