1717"""Meta Schedule tuning context."""
1818
1919import logging
20- from typing import TYPE_CHECKING , Dict , List , Optional
20+ from typing import TYPE_CHECKING , Dict , List , Optional , Union
2121
2222from tvm import IRModule
2323from tvm ._ffi import register_object
3636 from .runner import RunnerResult
3737 from .schedule_rule import ScheduleRule
3838 from .search_strategy import MeasureCandidate , SearchStrategy
39- from .space_generator import SpaceGenerator
39+ from .space_generator import SCH_FN_TYPE , SpaceGenerator
40+ from .tune import TuneConfig
4041
4142
4243@register_object ("meta_schedule.TuneContext" )
@@ -54,16 +55,24 @@ class TuneContext(Object):
5455 The workload to be optimized.
5556 target : Optional[Target] = None
5657 The target to be optimized for.
57- space_generator : Optional[ SpaceGenerator] = None
58+ space_generator : Union[None, SCH_FN_TYPE, SpaceGenerator] = None
5859 The design space generator.
59- search_strategy : Optional[ SearchStrategy] = None
60+ search_strategy : Union[None, TuneConfig, SearchStrategy] = None
6061 The search strategy.
61- sch_rules: Optional[List[ScheduleRule]] = None,
62+ if None, the strategy is left blank.
63+ If TuneConfig, the strategy is initialized with the TuneConfig.create_strategy().
64+ sch_rules: Union[None, str, List[ScheduleRule]] = None,
6265 The schedule rules.
63- postprocs: Optional[List[Postproc"]] = None,
66+ If None, use an empty list of rules.
67+ if "default", use target-default rules.
68+ postprocs: Union[None, str, List[Postproc"]] = None,
6469 The postprocessors.
65- mutator_probs: Optional[Dict[Mutator, float]]
70+ If None, use an empty list of rules.
71+ if "default", use target-default rules.
72+ mutator_probs: Union[None, str, Dict[Mutator, float]]
6673 Mutators and their probability mass.
74+ If None, use an empty list of rules.
75+ if "default", use target-default rules.
6776 task_name : Optional[str] = None
6877 The name of the tuning task.
6978 logger : logging.Logger
@@ -99,24 +108,53 @@ def __init__(
99108 mod : Optional [IRModule ] = None ,
100109 * ,
101110 target : Optional [Target ] = None ,
102- space_generator : Optional [ "SpaceGenerator" ] = None ,
103- search_strategy : Optional [ "SearchStrategy" ] = None ,
104- sch_rules : Optional [ List ["ScheduleRule" ]] = None ,
105- postprocs : Optional [ List ["Postproc" ]] = None ,
106- mutator_probs : Optional [ Dict ["Mutator" , float ]] = None ,
111+ space_generator : Union [ None , "SCH_FN_TYPE" , "SpaceGenerator" ] = None ,
112+ search_strategy : Union [ None , "SearchStrategy" , "TuneConfig " ] = None ,
113+ sch_rules : Union [ None , str , List ["ScheduleRule" ]] = None ,
114+ postprocs : Union [ None , str , List ["Postproc" ]] = None ,
115+ mutator_probs : Union [ None , str , Dict ["Mutator" , float ]] = None ,
107116 task_name : str = "main" ,
108117 logger : Optional [logging .Logger ] = None ,
109118 rand_state : int = - 1 ,
110119 num_threads : Optional [int ] = None ,
111120 ):
121+ # pylint: disable=import-outside-toplevel
122+ from . import default_config
123+ from .space_generator import ScheduleFn
124+ from .tune import TuneConfig
125+
126+ # pylint: enable=import-outside-toplevel
112127 if isinstance (mod , PrimFunc ):
113128 mod = IRModule .from_expr (mod )
114- if num_threads is None :
115- num_threads = cpu_count ()
129+ if callable (space_generator ):
130+ space_generator = ScheduleFn (space_generator )
131+ if isinstance (search_strategy , TuneConfig ):
132+ search_strategy = search_strategy .create_strategy ()
133+ if isinstance (sch_rules , str ):
134+ if sch_rules == "default" :
135+ if target is None :
136+ raise ValueError ("target is required when sch_rules is 'default'" )
137+ sch_rules = default_config .schedule_rules (None , target )
138+ else :
139+ raise ValueError ("sch_rules should be a list of ScheduleRule or 'default'" )
140+ if isinstance (postprocs , str ):
141+ if postprocs == "default" :
142+ if target is None :
143+ raise ValueError ("target is required when postprocs is 'default'" )
144+ postprocs = default_config .postproc (None , target )
145+ else :
146+ raise ValueError ("postprocs should be a list of Postproc or 'default'" )
147+ if isinstance (mutator_probs , str ):
148+ if mutator_probs == "default" :
149+ if target is None :
150+ raise ValueError ("target is required when mutator_probs is 'default'" )
151+ mutator_probs = default_config .mutator_probs (None , target )
116152 if logger is None :
117153 self .logger = logging .getLogger (__name__ )
118154 else :
119155 self .logger = None
156+ if num_threads is None :
157+ num_threads = cpu_count ()
120158 self .__init_handle_by_constructor__ (
121159 _ffi_api .TuneContext , # type: ignore # pylint: disable=no-member
122160 mod ,
@@ -131,9 +169,6 @@ def __init__(
131169 rand_state ,
132170 num_threads ,
133171 )
134-
135- def initialize (self ):
136- """Initialize the tuning context"""
137172 _ffi_api .TuneContextInitialize (self ) # type: ignore # pylint: disable=no-member
138173
139174 def generate_design_space (self ) -> List [Schedule ]:
0 commit comments