Skip to content

Commit 4f8f308

Browse files
committed
Squashed commit of the following:
commit 1711be3 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 13:11:40 2022 +0900 fixed condition for real commit 8a48fb5 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 09:57:42 2022 +0900 Revert "Skip applying sch_rule when both ann and sch_rule are defined" This reverts commit 4915c6a. commit daea033 Author: Masahiro Masuda <[email protected]> Date: Mon Apr 11 09:31:05 2022 +0900 [Metaschedule] Support rocm and spirv commit eb0cae2 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 07:25:04 2022 +0900 dp4a works commit 4915c6a Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 06:13:45 2022 +0900 Skip applying sch_rule when both ann and sch_rule are defined commit 7b3d71c Author: Masahiro Masuda <[email protected]> Date: Wed Apr 13 04:40:31 2022 +0900 fixed intrin description commit 7666cd7 Author: Masahiro Masuda <[email protected]> Date: Tue Apr 12 19:59:47 2022 +0900 add DP4A intrin commit 7086bdb Author: Masahiro Masuda <[email protected]> Date: Tue Apr 12 19:03:44 2022 +0900 works commit db34397 Author: Masahiro Masuda <[email protected]> Date: Tue Apr 12 12:49:52 2022 +0900 more hack to tensorize loop mapping to make resnet50 e2e work commit 2409674 Author: Masahiro Masuda <[email protected]> Date: Mon Apr 11 13:40:59 2022 +0900 wip support pad + qnn.conv2d folding commit 613cb7e Author: Masahiro Masuda <[email protected]> Date: Sun Apr 10 12:04:08 2022 +0900 hack to tensorize loop mapping to make conv2d work commit 9e4f9df Author: Masahiro Masuda <[email protected]> Date: Sun Apr 10 11:34:13 2022 +0900 wrap tensorize with try/catch commit d4b496d Author: Masahiro Masuda <[email protected]> Date: Sun Apr 10 11:33:39 2022 +0900 revert change in task_scheduler.cc commit 476129b Author: Masahiro Masuda <[email protected]> Date: Sat Apr 9 05:54:10 2022 +0900 try / catch in ThreadedApply commit d8226ff Author: Masahiro Masuda <[email protected]> Date: Fri Apr 8 17:17:59 2022 +0900 filter out invalid candidate commit 2632899 Author: Masahiro Masuda <[email protected]> Date: Fri Apr 8 10:09:48 2022 +0900 try graceful exit in parallel_for_dynamic commit 9d6741c Author: Masahiro Masuda <[email protected]> Date: Fri Apr 8 09:35:51 2022 +0900 [QNN] Fix broadcast for invalid axis commit 6ccde09 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 20:51:15 2022 +0900 refactor rewrite_tensorize commit 2ce2066 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 20:48:17 2022 +0900 allow missing schedule_rule in post order apply commit 3a69353 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 19:42:48 2022 +0900 refactor rewrite_tensorize commit 43e0b2f Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 18:25:14 2022 +0900 rewrite_vnni -> rewrite_tensorize commit 823797e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 18:12:12 2022 +0900 VNNI -> WithIntrin commit 4284a47 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:45:41 2022 +0900 introduce TileForIntrin commit b87ef32 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:34:04 2022 +0900 move TilingwithTensorIntrin to auto_tensorize.cc commit 2fc118b Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:28:45 2022 +0900 clean up headers commit d8b2aa3 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:09:32 2022 +0900 clean up using namespace commit eb05d25 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 17:03:05 2022 +0900 refactored init commit 5e6b0a0 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 16:57:14 2022 +0900 compiled commit 2b8c430 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 12:51:55 2022 +0900 wip MultiLevelTiling refactor commit 7c21a9f Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:58:33 2022 +0900 function doc string not supported by tvmscript commit 40f9742 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:56:45 2022 +0900 update vnni intrin name commit 4814f82 Merge: e0c5eb8 07bbb38 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:44:47 2022 +0900 Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni commit 07bbb38 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:24:56 2022 +0900 more lint fix commit 15e60b4 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:16:08 2022 +0900 black commit 7a757fe Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:12:54 2022 +0900 pylint commit 9a3e508 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:58:52 2022 +0900 simplify import commit d8e43ec Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:52:50 2022 +0900 use vectorlow/high in arm intrin commit 625cd27 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:34:57 2022 +0900 fixed offset factor commit 69e72b6 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:12:02 2022 +0900 Add ARM intrin commit 1351fde Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:27:27 2022 +0900 use buffer syntax sugar commit 0ced85f Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:17:43 2022 +0900 rename vnni.py to x86.py commit 38a5aca Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:24:44 2022 +0900 add VNNI unittest commit 88b763e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:10:06 2022 +0900 refactored existing test using VNNI intrin commit 711a007 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:04:58 2022 +0900 [TIR] Add VNNI dot product intrinsic for TIR commit e0c5eb8 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:42:26 2022 +0900 merge fix commit b171748 Merge: 71fe3bd 82e152a Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:33:59 2022 +0900 Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni commit 71fe3bd Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:57:38 2022 +0900 move tensor intrin under tir commit 0c51bad Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:12:39 2022 +0900 remove log commit fed910e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:11:22 2022 +0900 more revert commit 7150aff Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:10:44 2022 +0900 revert stmt_functor change commit 155107b Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 06:10:09 2022 +0900 refactored RewriteVNNI a bit commit ca15255 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 05:41:13 2022 +0900 add RewriteVNNI commit dc9f71d Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 05:38:56 2022 +0900 vectorized init loop commit fcc31ee Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 04:55:36 2022 +0900 tensorize worked commit 2b53437 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 6 19:11:05 2022 +0900 TilingwithTensorIntrin works commit 86baa31 Author: Masahiro Masuda <[email protected]> Date: Wed Apr 6 08:58:27 2022 +0900 Ported auto-tensorization code commit 82e152a Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:24:56 2022 +0900 more lint fix commit 88d9bdd Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:16:08 2022 +0900 black commit 31fe7eb Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 11:12:54 2022 +0900 pylint commit 7876754 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:58:52 2022 +0900 simplify import commit 56f2e9a Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:52:50 2022 +0900 use vectorlow/high in arm intrin commit 995cc8d Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:34:57 2022 +0900 fixed offset factor commit 86bbd49 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 10:12:02 2022 +0900 Add ARM intrin commit 120fd96 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:27:27 2022 +0900 use buffer syntax sugar commit 0f0682d Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 08:17:43 2022 +0900 rename vnni.py to x86.py commit f88c31e Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:24:44 2022 +0900 add VNNI unittest commit 6cc8009 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:10:06 2022 +0900 refactored existing test using VNNI intrin commit 11a29c7 Author: Masahiro Masuda <[email protected]> Date: Thu Apr 7 07:04:58 2022 +0900 [TIR] Add VNNI dot product intrinsic for TIR
1 parent 597000c commit 4f8f308

File tree

20 files changed

+902
-196
lines changed

20 files changed

+902
-196
lines changed

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,16 @@ class ScheduleRule : public runtime::ObjectRef {
150150
Optional<Array<Integer>> vector_load_lens, //
151151
Optional<Map<String, ObjectRef>> reuse_read, //
152152
Optional<Map<String, ObjectRef>> reuse_write);
153+
154+
TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
155+
String intrin_name, //
156+
String structure, //
157+
Optional<Array<String>> tile_binds, //
158+
Optional<Integer> max_innermost_factor, //
159+
Optional<Array<Integer>> vector_load_lens, //
160+
Optional<Map<String, ObjectRef>> reuse_read, //
161+
Optional<Map<String, ObjectRef>> reuse_write);
162+
153163
/*!
154164
* \brief Create a rule: add-rfactor to some blocks if needed
155165
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the

include/tvm/tir/stmt.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
15091509
/*! \brief Mark auto-unroll setting on the block. */
15101510
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
15111511

1512+
/*!
1513+
* \brief Mark that the block should be further rewritten using tensorization.
1514+
*/
1515+
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1516+
15121517
/*!
15131518
* \brief Check if attr_key is a pragma key extension
15141519
* \param attr_key The attr key to be compared

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
from .rewrite_reduction_block import RewriteReductionBlock
2323
from .rewrite_unbound_block import RewriteUnboundBlock
2424
from .verify_gpu_code import VerifyGPUCode
25+
from .rewrite_tensorize import RewriteTensorize
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that tensorize related components."""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
import tvm.tir.tensor_intrin
23+
24+
25+
@register_object("meta_schedule.RewriteTensorize")
26+
class RewriteTensorize(Postproc):
27+
"""A postprocessor that tensorize related components."""
28+
29+
def __init__(self, vectorize_init_loop=False) -> None:
30+
self.__init_handle_by_constructor__(
31+
_ffi_api.PostprocRewriteTensorize, # type: ignore # pylint: disable=no-member
32+
vectorize_init_loop
33+
)

python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .add_rfactor import AddRFactor
2323
from .auto_inline import AutoInline
2424
from .cross_thread_reduction import CrossThreadReduction
25-
from .multi_level_tiling import MultiLevelTiling, ReuseType
25+
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
2626
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
2727
from .random_compute_location import RandomComputeLocation
2828
from .schedule_rule import PyScheduleRule, ScheduleRule

python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,50 @@ def __init__(
8282
reuse_read.as_dict() if reuse_read is not None else None,
8383
reuse_write.as_dict() if reuse_write is not None else None,
8484
)
85+
86+
87+
@register_object("meta_schedule.MultiLevelTilingWithIntrin")
88+
class MultiLevelTilingWithIntrin(ScheduleRule):
89+
"""Multi-level tiling with reuse.
90+
91+
Parameters
92+
----------
93+
structure : str
94+
The tiling structure. Recommended:
95+
- 'SSRSRS' on CPU
96+
- 'SSSRRSRS' on GPU
97+
tile_bind : Optional[List[str]]
98+
For each level of tiles, which thread axis it is bound to. Recommended:
99+
- None on CPU
100+
- [blockIdx.x, vthread.x, threadIdx.x] on GPU
101+
max_innermost_factor : Optional[int]
102+
The maximum size of the innermost factor. None means no limit
103+
vector_load_lens : Optional[List[int]]
104+
The length of vector lane in vectorized cooperative fetching.
105+
None means disable vectorization
106+
reuse_read : Optional[ReuseType]
107+
Data reuse configuration for reading. None means no reuse.
108+
reuse_write : Optional[ReuseType]
109+
Data reuse configuration for writing. None means no reuse.
110+
"""
111+
112+
def __init__(
113+
self,
114+
intrin_name: str,
115+
structure: str,
116+
tile_binds: Optional[List[str]] = None,
117+
max_innermost_factor: Optional[int] = None,
118+
vector_load_lens: Optional[List[int]] = None,
119+
reuse_read: Optional[ReuseType] = None,
120+
reuse_write: Optional[ReuseType] = None,
121+
) -> None:
122+
self.__init_handle_by_constructor__(
123+
_ffi_api.ScheduleRuleMultiLevelTilingWithIntrin, # type: ignore # pylint: disable=no-member
124+
intrin_name,
125+
structure,
126+
tile_binds,
127+
max_innermost_factor,
128+
vector_load_lens,
129+
reuse_read.as_dict() if reuse_read is not None else None,
130+
reuse_write.as_dict() if reuse_write is not None else None,
131+
)

python/tvm/meta_schedule/tune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[Sche
411411
# pylint: disable=protected-access
412412
if target.kind.name == "llvm":
413413
return DefaultLLVM._sch_rules()
414-
if target.kind.name == "cuda":
414+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
415415
return DefaultCUDA._sch_rules()
416416
# pylint: enable=protected-access
417417
raise ValueError(f"Unsupported target: {target}")
@@ -425,7 +425,7 @@ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]:
425425
# pylint: disable=protected-access
426426
if target.kind.name == "llvm":
427427
return DefaultLLVM._postproc()
428-
if target.kind.name == "cuda":
428+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
429429
return DefaultCUDA._postproc()
430430
# pylint: enable=protected-access
431431
raise ValueError(f"Unsupported target: {target}")
@@ -444,7 +444,7 @@ def _mutator_probs(
444444
# pylint: disable=protected-access
445445
if target.kind.name == "llvm":
446446
return DefaultLLVM._mutator_probs()
447-
if target.kind.name == "cuda":
447+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
448448
return DefaultCUDA._mutator_probs()
449449
# pylint: enable=protected-access
450450
raise ValueError(f"Unsupported target: {target}")

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
"""Intrinsics for tensorization."""
1919
from .x86 import *
2020
from .arm_cpu import *
21+
from .dot_product_common import *
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name,missing-function-docstring
18+
"""Dot product related intrinsics."""
19+
from tvm.script import tir as T
20+
from .. import TensorIntrin
21+
22+
23+
@T.prim_func
24+
def dp4a_desc(
25+
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
26+
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
27+
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
28+
) -> None:
29+
with T.block("root"):
30+
T.reads(C[0], A[0:4], B[0:4])
31+
T.writes(C[0])
32+
for i in range(0, 4):
33+
with T.block("update"):
34+
vi = T.axis.remap("R", [i])
35+
C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
36+
37+
38+
@T.prim_func
39+
def dp4a_impl(
40+
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
41+
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
42+
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
43+
) -> None:
44+
with T.block("root"):
45+
T.reads(C[0], A[0:4], B[0:4])
46+
T.writes(C[0])
47+
48+
C[0] += T.call_pure_extern(
49+
"__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
50+
)
51+
52+
53+
DP4A_INTRIN = "dp4a"
54+
55+
TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include <algorithm>
20+
21+
#include "../utils.h"
22+
#include "tvm/runtime/container/base.h"
23+
24+
namespace tvm {
25+
namespace meta_schedule {
26+
27+
using tir::BlockRV;
28+
using tir::LoopRV;
29+
30+
void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
31+
const tir::PrimFuncNode* func, bool vectorize_init_loop) {
32+
std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;
33+
34+
tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) -> bool {
35+
if (const auto* block = obj.as<tir::BlockNode>()) {
36+
tir::StmtSRef block_sref = sch->GetSRef(block);
37+
if (Optional<String> intrin_name =
38+
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
39+
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
40+
if (block_name.find("init") == std::string::npos) {
41+
jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) {
42+
try {
43+
sch->Tensorize(block, intrin_name.value());
44+
} catch (const std::exception& e) {
45+
LOG(WARNING) << "Tensorize failed with error " << e.what();
46+
}
47+
});
48+
} else if (vectorize_init_loop) {
49+
jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
50+
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
51+
ICHECK(child_blocks.size() == 1);
52+
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
53+
ICHECK(init_loops.size() == 1);
54+
sch->Vectorize(init_loops[0]);
55+
});
56+
}
57+
}
58+
}
59+
return true;
60+
});
61+
62+
for (auto kv : jobs) {
63+
tir::BlockRV block = sch->GetBlock(kv.first, func_name);
64+
sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
65+
kv.second(block);
66+
}
67+
}
68+
69+
class RewriteTensorizeNode : public PostprocNode {
70+
public:
71+
void InitializeWithTuneContext(const TuneContext& context) final {}
72+
73+
bool Apply(const tir::Schedule& sch) final;
74+
75+
void VisitAttrs(tvm::AttrVisitor* v) {}
76+
77+
bool vectorize_init_loop = false;
78+
79+
static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
80+
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode);
81+
};
82+
83+
bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
84+
for (const auto& kv : sch->mod()->functions) {
85+
GlobalVar g_var = kv.first;
86+
BaseFunc base_func = kv.second;
87+
if (const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>()) {
88+
ApplyTensorization(sch, g_var->name_hint, prim_func, vectorize_init_loop);
89+
}
90+
}
91+
return true;
92+
}
93+
94+
Postproc RewriteTensorize(bool vectorize_init_loop) {
95+
ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>();
96+
n->vectorize_init_loop = vectorize_init_loop;
97+
return Postproc(n);
98+
}
99+
100+
TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
101+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize").set_body_typed(RewriteTensorize);
102+
103+
} // namespace meta_schedule
104+
} // namespace tvm

0 commit comments

Comments
 (0)