Skip to content

Commit 0602f4a

Browse files
committed
Squashed commit of the following:
commit 65b8bcf Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 20:36:49 2022 +0900 [WIP] adding DP4A support to rocm commit 4f8f308 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 14:03:25 2022 +0900 Squashed commit of the following: commit 1711be3 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 13:11:40 2022 +0900 fixed condition for real commit 8a48fb5 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 09:57:42 2022 +0900 Revert "Skip applying sch_rule when both ann and sch_rule are defined" This reverts commit 4915c6a. commit daea033 Author: Masahiro Masuda <[email protected]> Date: Mon Apr 11 09:31:05 2022 +0900 [Metaschedule] Support rocm and spirv commit eb0cae2 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 07:25:04 2022 +0900 dp4a works commit 4915c6a Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 06:13:45 2022 +0900 Skip applying sch_rule when both ann and sch_rule are defined commit 7b3d71c Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 04:40:31 2022 +0900 fixed intrin description commit 7666cd7 Author: Masahiro Masuda <[email protected]> Date: Tue Apr 12 19:59:47 2022 +0900 add DP4A intrin commit 7086bdb Author: Masahiro Masuda <[email protected]> Date: Tue Apr 12 19:03:44 2022 +0900 works commit db34397 Author: Masahiro Masuda <[email protected]> Date: Tue Apr 12 12:49:52 2022 +0900 more hack to tensorize loop mapping to make resnet50 e2e work commit 2409674 Author: Masahiro Masuda <[email protected]> Date: Mon Apr 11 13:40:59 2022 +0900 wip support pad + qnn.conv2d folding commit 613cb7e Author: Masahiro Masuda <[email protected]> Date: Sun Apr 10 12:04:08 2022 +0900 hack to tensorize loop mapping to make conv2d work commit 9e4f9df Author: Masahiro Masuda <[email protected]> Date: Sun Apr 10 11:34:13 2022 +0900 wrap tensorize with try/catch commit d4b496d Author: Masahiro Masuda <[email protected]> Date: Sun Apr 10 11:33:39 2022 +0900 revert change in task_scheduler.cc commit 476129b Author: Masahiro Masuda <[email protected]> Date: Sat Apr 9 05:54:10 2022 +0900 try / catch in ThreadedApply commit d8226ff Author: Masahiro Masuda <[email protected]> Date: Fri Apr 8 17:17:59 2022 +0900 filter out invalid candidate commit 2632899 Author: Masahiro Masuda <[email protected]> Date: Fri Apr 8 10:09:48 2022 +0900 try graceful exit in parallel_for_dynamic commit 9d6741c Author: Masahiro Masuda <[email protected]> Date: Fri Apr 8 09:35:51 2022 +0900 [QNN] Fix broadcast for invalid axis commit 6ccde09 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 20:51:15 2022 +0900 refactor rewrite_tensorize commit 2ce2066 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 20:48:17 2022 +0900 allow missing schedule_rule in post order apply commit 3a69353 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 19:42:48 2022 +0900 refactor rewrite_tensorize commit 43e0b2f Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 18:25:14 2022 +0900 rewrite_vnni -> rewrite_tensorize commit 823797e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 18:12:12 2022 +0900 VNNI -> WithIntrin commit 4284a47 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:45:41 2022 +0900 introduce TileForIntrin commit b87ef32 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:34:04 2022 +0900 move TilingwithTensorIntrin to auto_tensorize.cc commit 2fc118b Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:28:45 2022 +0900 clean up headers commit d8b2aa3 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:09:32 2022 +0900 clean up using namespace commit eb05d25 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:03:05 2022 +0900 refactored init commit 5e6b0a0 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 16:57:14 2022 +0900 compiled commit 2b8c430 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 12:51:55 2022 +0900 wip MultiLevelTiling refactor commit 7c21a9f Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:58:33 2022 +0900 function doc string not supported by tvmscript commit 40f9742 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:56:45 2022 +0900 update vnni intrin name commit 4814f82 Merge: e0c5eb8 07bbb38 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:44:47 2022 +0900 Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni commit 07bbb38 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:24:56 2022 +0900 more lint fix commit 15e60b4 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:16:08 2022 +0900 black commit 7a757fe Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:12:54 2022 +0900 pylint commit 9a3e508 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:58:52 2022 +0900 simplify import commit d8e43ec Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:52:50 2022 +0900 use vectorlow/high in arm intrin commit 625cd27 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:34:57 2022 +0900 fixed offset factor commit 69e72b6 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:12:02 2022 +0900 Add ARM intrin commit 1351fde Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:27:27 2022 +0900 use buffer syntax sugar commit 0ced85f Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:17:43 2022 +0900 rename vnni.py to x86.py commit 38a5aca Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:24:44 2022 +0900 add VNNI unittest commit 88b763e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:10:06 2022 +0900 refactored existing test using VNNI intrin commit 711a007 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:04:58 2022 +0900 [TIR] Add VNNI dot product intrinsic for TIR commit e0c5eb8 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:42:26 2022 +0900 merge fix commit b171748 Merge: 71fe3bd 82e152a Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:33:59 2022 +0900 Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni commit 71fe3bd Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:57:38 2022 +0900 move tensor intrin under tir commit 0c51bad Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:12:39 2022 +0900 remove log commit fed910e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:11:22 2022 +0900 more revert commit 7150aff Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:10:44 2022 +0900 revert stmt_functor change commit 155107b Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:10:09 2022 +0900 refactored RewriteVNNI a bit commit ca15255 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 05:41:13 2022 +0900 add RewriteVNNI commit dc9f71d Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 05:38:56 2022 +0900 vectorized init loop commit fcc31ee Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 04:55:36 2022 +0900 tensorize worked commit 2b53437 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 6 19:11:05 2022 +0900 TilingwithTensorIntrin works commit 86baa31 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 6 08:58:27 2022 +0900 Ported auto-tensorization code commit 82e152a Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:24:56 2022 +0900 more lint fix commit 88d9bdd Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:16:08 2022 +0900 black commit 31fe7eb Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:12:54 2022 +0900 pylint commit 7876754 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:58:52 2022 +0900 simplify import commit 56f2e9a Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:52:50 2022 +0900 use vectorlow/high in arm intrin commit 995cc8d Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:34:57 2022 +0900 fixed offset factor commit 86bbd49 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:12:02 2022 +0900 Add ARM intrin commit 120fd96 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:27:27 2022 +0900 use buffer syntax sugar commit 0f0682d Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:17:43 2022 +0900 rename vnni.py to x86.py commit f88c31e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:24:44 2022 +0900 add VNNI unittest commit 6cc8009 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:10:06 2022 +0900 refactored existing test using VNNI intrin commit 11a29c7 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:04:58 2022 +0900 [TIR] Add VNNI dot product intrinsic for TIR
1 parent 1bfb9ca commit 0602f4a

File tree

25 files changed

+906
-206
lines changed

25 files changed

+906
-206
lines changed

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,16 @@ class ScheduleRule : public runtime::ObjectRef {
150150
Optional<Array<Integer>> vector_load_lens, //
151151
Optional<Map<String, ObjectRef>> reuse_read, //
152152
Optional<Map<String, ObjectRef>> reuse_write);
153+
154+
TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
155+
String intrin_name, //
156+
String structure, //
157+
Optional<Array<String>> tile_binds, //
158+
Optional<Integer> max_innermost_factor, //
159+
Optional<Array<Integer>> vector_load_lens, //
160+
Optional<Map<String, ObjectRef>> reuse_read, //
161+
Optional<Map<String, ObjectRef>> reuse_write);
162+
153163
/*!
154164
* \brief Create a rule: add-rfactor to some blocks if needed
155165
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the

include/tvm/tir/stmt.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
15091509
/*! \brief Mark auto-unroll setting on the block. */
15101510
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
15111511

1512+
/*!
1513+
* \brief Mark that the block should be further rewritten using tensorization.
1514+
*/
1515+
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1516+
15121517
/*!
15131518
* \brief Check if attr_key is a pragma key extension
15141519
* \param attr_key The attr key to be compared

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
from .rewrite_reduction_block import RewriteReductionBlock
2323
from .rewrite_unbound_block import RewriteUnboundBlock
2424
from .verify_gpu_code import VerifyGPUCode
25+
from .rewrite_tensorize import RewriteTensorize
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that tensorize related components."""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
import tvm.tir.tensor_intrin
23+
24+
25+
@register_object("meta_schedule.RewriteTensorize")
26+
class RewriteTensorize(Postproc):
27+
"""A postprocessor that tensorize related components."""
28+
29+
def __init__(self, vectorize_init_loop=False) -> None:
30+
self.__init_handle_by_constructor__(
31+
_ffi_api.PostprocRewriteTensorize, # type: ignore # pylint: disable=no-member
32+
vectorize_init_loop
33+
)

python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .add_rfactor import AddRFactor
2323
from .auto_inline import AutoInline
2424
from .cross_thread_reduction import CrossThreadReduction
25-
from .multi_level_tiling import MultiLevelTiling, ReuseType
25+
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
2626
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
2727
from .random_compute_location import RandomComputeLocation
2828
from .schedule_rule import PyScheduleRule, ScheduleRule

python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,50 @@ def __init__(
8282
reuse_read.as_dict() if reuse_read is not None else None,
8383
reuse_write.as_dict() if reuse_write is not None else None,
8484
)
85+
86+
87+
@register_object("meta_schedule.MultiLevelTilingWithIntrin")
88+
class MultiLevelTilingWithIntrin(ScheduleRule):
89+
"""Multi-level tiling with reuse.
90+
91+
Parameters
92+
----------
93+
structure : str
94+
The tiling structure. Recommended:
95+
- 'SSRSRS' on CPU
96+
- 'SSSRRSRS' on GPU
97+
tile_bind : Optional[List[str]]
98+
For each level of tiles, which thread axis it is bound to. Recommended:
99+
- None on CPU
100+
- [blockIdx.x, vthread.x, threadIdx.x] on GPU
101+
max_innermost_factor : Optional[int]
102+
The maximum size of the innermost factor. None means no limit
103+
vector_load_lens : Optional[List[int]]
104+
The length of vector lane in vectorized cooperative fetching.
105+
None means disable vectorization
106+
reuse_read : Optional[ReuseType]
107+
Data reuse configuration for reading. None means no reuse.
108+
reuse_write : Optional[ReuseType]
109+
Data reuse configuration for writing. None means no reuse.
110+
"""
111+
112+
def __init__(
113+
self,
114+
intrin_name: str,
115+
structure: str,
116+
tile_binds: Optional[List[str]] = None,
117+
max_innermost_factor: Optional[int] = None,
118+
vector_load_lens: Optional[List[int]] = None,
119+
reuse_read: Optional[ReuseType] = None,
120+
reuse_write: Optional[ReuseType] = None,
121+
) -> None:
122+
self.__init_handle_by_constructor__(
123+
_ffi_api.ScheduleRuleMultiLevelTilingWithIntrin, # type: ignore # pylint: disable=no-member
124+
intrin_name,
125+
structure,
126+
tile_binds,
127+
max_innermost_factor,
128+
vector_load_lens,
129+
reuse_read.as_dict() if reuse_read is not None else None,
130+
reuse_write.as_dict() if reuse_write is not None else None,
131+
)

python/tvm/meta_schedule/tune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[Sche
411411
# pylint: disable=protected-access
412412
if target.kind.name == "llvm":
413413
return DefaultLLVM._sch_rules()
414-
if target.kind.name == "cuda":
414+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
415415
return DefaultCUDA._sch_rules()
416416
# pylint: enable=protected-access
417417
raise ValueError(f"Unsupported target: {target}")
@@ -425,7 +425,7 @@ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]:
425425
# pylint: disable=protected-access
426426
if target.kind.name == "llvm":
427427
return DefaultLLVM._postproc()
428-
if target.kind.name == "cuda":
428+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
429429
return DefaultCUDA._postproc()
430430
# pylint: enable=protected-access
431431
raise ValueError(f"Unsupported target: {target}")
@@ -444,7 +444,7 @@ def _mutator_probs(
444444
# pylint: disable=protected-access
445445
if target.kind.name == "llvm":
446446
return DefaultLLVM._mutator_probs()
447-
if target.kind.name == "cuda":
447+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
448448
return DefaultCUDA._mutator_probs()
449449
# pylint: enable=protected-access
450450
raise ValueError(f"Unsupported target: {target}")

python/tvm/relay/op/strategy/cuda.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
145145
if layout == "NCHW":
146146
assert kernel_layout == "OIHW"
147147
if (
148-
(target.kind.name in ["cuda", "vulkan"])
148+
(target.kind.name in ["cuda", "vulkan", "rocm"])
149149
and data.dtype in ("int8", "uint8")
150150
and kernel.dtype in ("int8", "uint8")
151151
):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
297297
Need to satisfy tensor core schedule."
298298
)
299299
elif (
300-
(target.kind.name in ["cuda", "vulkan"])
300+
(target.kind.name in ["cuda", "vulkan", "rocm"])
301301
and layout == "NCHW4c"
302302
and data.dtype in ["int8", "uint8"]
303303
):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
376376
ic_chunk = in_channels // 4
377377

378378
if (
379-
(target.kind.name in ["cuda", "vulkan"])
379+
(target.kind.name in ["cuda", "vulkan", "rocm"])
380380
and data.dtype in ["int8", "uint8"]
381381
and kernel.dtype in ["int8", "uint8"]
382382
and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
836836
b, i = get_const_tuple(data.shape)
837837
o, _ = get_const_tuple(weights.shape)
838838
if (
839-
target.kind.name in ["cuda", "vulkan"]
839+
target.kind.name in ["cuda", "vulkan", "rocm"]
840840
and data.dtype == "int8"
841841
and weights.dtype == "int8"
842842
and out_type.dtype == "int32"

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ def is_aarch64_arm():
387387
return "aarch64" in target.attrs.get("mtriple", "")
388388

389389

390+
def is_rocm():
391+
"""Checks whether we are compiling for a rocm/spirv target."""
392+
target = tvm.target.Target.current(allow_none=False)
393+
return "rocm" in target.keys
394+
395+
390396
def is_vulkan():
391397
"""Checks whether we are compiling for a vulkan/spirv target."""
392398
target = tvm.target.Target.current(allow_none=False)
@@ -456,7 +462,7 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
456462

457463
@qnn_conv2d_legalize.register(["cuda", "gpu"])
458464
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
459-
if is_vulkan():
465+
if is_vulkan() or is_rocm():
460466
# prefers the dtypes to be same. Mixed type is not yet supported.
461467
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
462468
if is_cuda():
@@ -467,7 +473,7 @@ def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
467473

468474
@qnn_dense_legalize.register(["cuda", "gpu"])
469475
def _qnn_dense_legalize_cuda(attrs, inputs, types):
470-
if is_vulkan():
476+
if is_vulkan() or is_rocm():
471477
# prefers the dtypes to be same. Mixed type is not yet supported.
472478
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
473479
if is_cuda():

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
"""Intrinsics for tensorization."""
1919
from .x86 import *
2020
from .arm_cpu import *
21+
from .dot_product_common import *

0 commit comments

Comments
 (0)