Skip to content

Commit 7c06de5

Browse files
authored
[Fix][MetaSchedule] Fix redundant stages in async pipeline for mlt (#14143)
This PR fixes redundant stages being accumulated when `InitializeWithTuneContext` is invoked multiple times.
1 parent 7d67bb1 commit 7c06de5

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
9696
if (std::stoi(sm) >= 80) {
9797
// only stages 4 and 5 are tested. Any integer greater than 2
9898
// is theoretically feasible, but great performance is not guaranteed.
99-
this->stages.insert(this->stages.end(), {4, 5});
99+
this->stages = {4, 5};
100100
}
101101
} catch (const std::invalid_argument& e) {
102102
LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Tests for MetaSchedule search space on CUDA"""
18+
from typing import List, Optional, Tuple, Union
19+
20+
# isort: off
21+
from typing_extensions import Literal
22+
23+
# isort: on
24+
from tvm.meta_schedule.testing.space_generation import get_rules
25+
from tvm import meta_schedule as ms
26+
from tvm.meta_schedule.testing.te_workload import create_te_workload
27+
from tvm.target import Target
28+
from tvm.ir import IRModule
29+
from tvm.tir import Schedule
30+
31+
32+
def generate_design_space(
33+
kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"],
34+
mod: IRModule,
35+
target: Target,
36+
types: Union[type, Tuple[type, ...]],
37+
sch_rules: Optional[List[ms.ScheduleRule]] = None,
38+
initialize_time: int = 1,
39+
) -> List[Schedule]:
40+
if sch_rules is None:
41+
sch_rules = get_rules(kind, types)
42+
else:
43+
assert types is None
44+
ctx = ms.TuneContext(
45+
mod=mod,
46+
target=target,
47+
space_generator=ms.space_generator.PostOrderApply(
48+
sch_rules=sch_rules,
49+
postprocs=[],
50+
mutator_probs={},
51+
),
52+
task_name="test",
53+
)
54+
# each clone triggers one more initialization
55+
for _ in range(initialize_time - 1):
56+
ctx = ctx.clone()
57+
return ctx.generate_design_space()
58+
59+
60+
def _target():
61+
return Target("nvidia/geforce-rtx-3070")
62+
63+
64+
def _design_space(mod):
65+
return generate_design_space(
66+
kind="cuda",
67+
mod=mod,
68+
target=_target(),
69+
types=ms.ScheduleRule,
70+
initialize_time=100,
71+
)
72+
73+
74+
def test_c2d():
75+
mod = create_te_workload("C2D", 0)
76+
actual = _design_space(mod)
77+
assert len(actual) == 3
78+
79+
80+
def test_gmm():
81+
mod = create_te_workload("GMM", 0)
82+
actual = _design_space(mod)
83+
assert len(actual) == 3
84+
85+
86+
if __name__ == "__main__":
87+
test_c2d()
88+
test_gmm()

0 commit comments

Comments
 (0)