Skip to content

Commit 89c10c7

Browse files
[MetaSchedule] Developer Ergonomics Enhancement II
Follow-up of apache#11622, per discussion with @Kathryn-cat - [x] Allow using a string `"default"` in `TuneContext` to quickly specify a set of target-specific rules - [x] Enhance detection of `ScheduleFn` in `TuneContext` to make it easier for users to quickly try out template-driven scheduling on TIR. Next PR: - Add `TuneContext.tune` to allow directly tuning without task scheduler. Co-Authored-By: Kathryn (Jinqi) Chen <[email protected]>
1 parent 1312658 commit 89c10c7

27 files changed

+63
-59
lines changed

python/tvm/meta_schedule/space_generator/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
Meta Schedule design space generators that generates design
2020
space for generation of measure candidates.
2121
"""
22-
from .space_generator import SpaceGenerator, PySpaceGenerator
23-
from .space_generator_union import SpaceGeneratorUnion
24-
from .schedule_fn import ScheduleFn
2522
from .post_order_apply import PostOrderApply
23+
from .schedule_fn import SCH_FN_TYPE, ScheduleFn
24+
from .space_generator import PySpaceGenerator, SpaceGenerator
25+
from .space_generator_union import SpaceGeneratorUnion

python/tvm/meta_schedule/space_generator/schedule_fn.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,17 @@
3030
if TYPE_CHECKING:
3131
from ..tune_context import TuneContext
3232

33+
SCH_FN_TYPE = Union[ # pylint: disable=invalid-name
34+
Callable[[Schedule], None], # No output
35+
Callable[[Schedule], Schedule], # Single output
36+
Callable[[Schedule], List[Schedule]], # Multiple outputs
37+
]
38+
3339

3440
@derived_object
3541
class ScheduleFn(PySpaceGenerator):
3642
"""A design space generator with design spaces specified by a schedule function."""
3743

38-
# Multiple cases of schedule functions supported
39-
SCH_FN_TYPE = Union[
40-
Callable[[Schedule], None], # No output
41-
Callable[[Schedule], Schedule], # Single output
42-
Callable[[Schedule], List[Schedule]], # Multiple outputs
43-
]
44-
4544
def __init__(self, sch_fn: SCH_FN_TYPE):
4645
"""Constructor.
4746

python/tvm/meta_schedule/tune.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class TuneConfig(NamedTuple):
8888
search_strategy_config: Optional[Dict[str, Any]] = None
8989
logger_config: Optional[Dict[str, Any]] = None
9090

91-
def create_strategy(self, **kwargs):
91+
def create_strategy(self):
9292
"""Create search strategy from configuration"""
9393
cls_tbl = {
9494
"evolutionary": EvolutionarySearch,
@@ -111,7 +111,6 @@ def create_strategy(self, **kwargs):
111111
return cls_tbl[self.strategy](
112112
num_trials_per_iter=self.num_trials_per_iter,
113113
max_trials_per_task=max_trials_per_task,
114-
**kwargs,
115114
**config,
116115
)
117116

python/tvm/meta_schedule/tune_context.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""Meta Schedule tuning context."""
1818

1919
import logging
20-
from typing import TYPE_CHECKING, Dict, List, Optional
20+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
2121

2222
from tvm import IRModule
2323
from tvm._ffi import register_object
@@ -36,7 +36,8 @@
3636
from .runner import RunnerResult
3737
from .schedule_rule import ScheduleRule
3838
from .search_strategy import MeasureCandidate, SearchStrategy
39-
from .space_generator import SpaceGenerator
39+
from .space_generator import SCH_FN_TYPE, SpaceGenerator
40+
from .tune import TuneConfig
4041

4142

4243
@register_object("meta_schedule.TuneContext")
@@ -54,16 +55,24 @@ class TuneContext(Object):
5455
The workload to be optimized.
5556
target : Optional[Target] = None
5657
The target to be optimized for.
57-
space_generator : Optional[SpaceGenerator] = None
58+
space_generator : Union[None, SCH_FN_TYPE, SpaceGenerator] = None
5859
The design space generator.
59-
search_strategy : Optional[SearchStrategy] = None
60+
search_strategy : Union[None, TuneConfig, SearchStrategy] = None
6061
The search strategy.
61-
sch_rules: Optional[List[ScheduleRule]] = None,
62+
if None, the strategy is left blank.
63+
If TuneConfig, the strategy is initialized with the TuneConfig.create_strategy().
64+
sch_rules: Union[None, str, List[ScheduleRule]] = None,
6265
The schedule rules.
63-
postprocs: Optional[List[Postproc"]] = None,
66+
If None, use an empty list of rules.
67+
if "default", use target-default rules.
68+
postprocs: Union[None, str, List[Postproc"]] = None,
6469
The postprocessors.
65-
mutator_probs: Optional[Dict[Mutator, float]]
70+
If None, use an empty list of rules.
71+
if "default", use target-default rules.
72+
mutator_probs: Union[None, str, Dict[Mutator, float]]
6673
Mutators and their probability mass.
74+
If None, use an empty list of rules.
75+
if "default", use target-default rules.
6776
task_name : Optional[str] = None
6877
The name of the tuning task.
6978
logger : logging.Logger
@@ -99,24 +108,53 @@ def __init__(
99108
mod: Optional[IRModule] = None,
100109
*,
101110
target: Optional[Target] = None,
102-
space_generator: Optional["SpaceGenerator"] = None,
103-
search_strategy: Optional["SearchStrategy"] = None,
104-
sch_rules: Optional[List["ScheduleRule"]] = None,
105-
postprocs: Optional[List["Postproc"]] = None,
106-
mutator_probs: Optional[Dict["Mutator", float]] = None,
111+
space_generator: Union[None, "SCH_FN_TYPE", "SpaceGenerator"] = None,
112+
search_strategy: Union[None, "SearchStrategy", "TuneConfig"] = None,
113+
sch_rules: Union[None, str, List["ScheduleRule"]] = None,
114+
postprocs: Union[None, str, List["Postproc"]] = None,
115+
mutator_probs: Union[None, str, Dict["Mutator", float]] = None,
107116
task_name: str = "main",
108117
logger: Optional[logging.Logger] = None,
109118
rand_state: int = -1,
110119
num_threads: Optional[int] = None,
111120
):
121+
# pylint: disable=import-outside-toplevel
122+
from . import default_config
123+
from .space_generator import ScheduleFn
124+
from .tune import TuneConfig
125+
126+
# pylint: enable=import-outside-toplevel
112127
if isinstance(mod, PrimFunc):
113128
mod = IRModule.from_expr(mod)
114-
if num_threads is None:
115-
num_threads = cpu_count()
129+
if callable(space_generator):
130+
space_generator = ScheduleFn(space_generator)
131+
if isinstance(search_strategy, TuneConfig):
132+
search_strategy = search_strategy.create_strategy()
133+
if isinstance(sch_rules, str):
134+
if sch_rules == "default":
135+
if target is None:
136+
raise ValueError("target is required when sch_rules is 'default'")
137+
sch_rules = default_config.schedule_rules(None, target)
138+
else:
139+
raise ValueError("sch_rules should be a list of ScheduleRule or 'default'")
140+
if isinstance(postprocs, str):
141+
if postprocs == "default":
142+
if target is None:
143+
raise ValueError("target is required when postprocs is 'default'")
144+
postprocs = default_config.postproc(None, target)
145+
else:
146+
raise ValueError("postprocs should be a list of Postproc or 'default'")
147+
if isinstance(mutator_probs, str):
148+
if mutator_probs == "default":
149+
if target is None:
150+
raise ValueError("target is required when mutator_probs is 'default'")
151+
mutator_probs = default_config.mutator_probs(None, target)
116152
if logger is None:
117153
self.logger = logging.getLogger(__name__)
118154
else:
119155
self.logger = None
156+
if num_threads is None:
157+
num_threads = cpu_count()
120158
self.__init_handle_by_constructor__(
121159
_ffi_api.TuneContext, # type: ignore # pylint: disable=no-member
122160
mod,
@@ -131,9 +169,6 @@ def __init__(
131169
rand_state,
132170
num_threads,
133171
)
134-
135-
def initialize(self):
136-
"""Initialize the tuning context"""
137172
_ffi_api.TuneContextInitialize(self) # type: ignore # pylint: disable=no-member
138173

139174
def generate_design_space(self) -> List[Schedule]:

tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,6 @@ def test_conv2d_winograd_cpu():
173173
target,
174174
),
175175
)
176-
context.initialize()
177176
post_order_apply = context.space_generator
178177
(sch,) = post_order_apply.generate_design_space(mod)
179178
decisions = dict(

tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ def test_conv2d_winograd_cuda():
290290
None, Target("cuda")
291291
),
292292
)
293-
context.initialize()
294293
post_order_apply = context.space_generator
295294
(sch,) = post_order_apply.generate_design_space(mod)
296295
decisions = dict(

tests/python/unittest/test_meta_schedule_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def traverse(t):
221221
mod,
222222
target="llvm",
223223
params=params,
224-
filter_func=filter_func,
224+
te_filter_func=filter_func,
225225
)
226226
expected_task_names = [
227227
"fused_" + s

tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def _make_mutator(target: Target) -> Mutator:
6969
MutateComputeLocation(): 1.0,
7070
},
7171
)
72-
ctx.initialize()
7372
return list(ctx.mutator_probs.keys())[0]
7473

7574

tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator:
8787
MutateParallel(max_jobs_per_core): 1.0,
8888
},
8989
)
90-
ctx.initialize()
9190
return list(ctx.mutator_probs.keys())[0]
9291

9392

tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def _make_mutator(target: Target) -> Mutator:
7070
MutateThreadBinding(): 1.0,
7171
},
7272
)
73-
ctx.initialize()
7473
return list(ctx.mutator_probs.keys())[0]
7574

7675

0 commit comments

Comments
 (0)