@@ -75,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object {
7575 Runner runner{nullptr };
7676 /* ! \brief The database of the scheduler. */
7777 Database database{nullptr };
78+ /* ! \brief The maximum number of trials allowed. */
79+ int max_trials;
7880 /* ! \brief The cost model of the scheduler. */
7981 Optional<CostModel> cost_model;
8082 /* ! \brief The list of measure callbacks of the scheduler. */
8183 Array<MeasureCallback> measure_callbacks;
84+ /* ! \brief The number of trials already conducted. */
85+ int num_trials_already;
8286
8387 /* ! \brief The default destructor. */
8488 virtual ~TaskSchedulerNode () = default ;
@@ -88,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object {
8892 v->Visit (" builder" , &builder);
8993 v->Visit (" runner" , &runner);
9094 v->Visit (" database" , &database);
95+ v->Visit (" max_trials" , &max_trials);
9196 v->Visit (" cost_model" , &cost_model);
9297 v->Visit (" measure_callbacks" , &measure_callbacks);
98+ v->Visit (" num_trials_already" , &num_trials_already);
9399 }
94100
95101 /* ! \brief Auto-tuning. */
@@ -102,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object {
102108 virtual void InitializeTask (int task_id);
103109
104110 /* !
105- * \brief Set specific task to be stopped.
106- * \param task_id The task id to be stopped.
107- */
108- virtual void SetTaskStopped (int task_id);
109-
110- /* !
111- * \brief Check whether the task is running.
111+ * \brief Touch the task and update its status
112112 * \param task_id The task id to be checked.
113- * \return Whether the task is running.
114113 */
115- virtual bool IsTaskRunning (int task_id);
114+ virtual void TouchTask (int task_id);
116115
117116 /* !
118117 * \brief Wait until the task is finished.
119118 * \param task_id The task id to be joined.
120119 */
121- virtual void JoinRunningTask (int task_id);
120+ virtual Array<RunnerResult> JoinRunningTask (int task_id);
122121
123122 /* !
124123 * \brief Fetch the next task id.
@@ -142,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
142141 using FInitializeTask = runtime::TypedPackedFunc<void (int )>;
143142
144143 /* !
145- * \brief The function type of `SetTaskStopped` method.
146- * \param task_id The task id to be stopped.
147- */
148- using FSetTaskStopped = runtime::TypedPackedFunc<void (int )>;
149-
150- /* !
151- * \brief The function type of `IsTaskRunning` method.
144+ * \brief The function type of `TouchTask` method.
152145 * \param task_id The task id to be checked.
153146 * \return Whether the task is running.
154147 */
155- using FIsTaskRunning = runtime::TypedPackedFunc<bool (int )>;
148+ using FTouchTask = runtime::TypedPackedFunc<void (int )>;
156149
157150 /* !
158151 * \brief The function type of `JoinRunningTask` method.
159152 * \param task_id The task id to be joined.
160153 */
161- using FJoinRunningTask = runtime::TypedPackedFunc<void (int )>;
154+ using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult> (int )>;
162155
163156 /* !
164157 * \brief The function type of `NextTaskId` method.
@@ -170,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
170163 FTune f_tune;
171164 /* ! \brief The packed function to the `InitializeTask` function. */
172165 FInitializeTask f_initialize_task;
173- /* ! \brief The packed function to the `SetTaskStopped` function. */
174- FSetTaskStopped f_set_task_stopped;
175- /* ! \brief The packed function to the `IsTaskRunning` function. */
176- FIsTaskRunning f_is_task_running;
166+ /* ! \brief The packed function to the `TouchTask` function. */
167+ FTouchTask f_touch_task;
177168 /* ! \brief The packed function to the `JoinRunningTask` function. */
178169 FJoinRunningTask f_join_running_task;
179170 /* ! \brief The packed function to the `NextTaskId` function. */
@@ -182,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
182173 void VisitAttrs (tvm::AttrVisitor* v) {
183174 // `f_tune` is not visited
184175 // `f_initialize_task` is not visited
185- // `f_set_task_stopped` is not visited
186- // `f_is_task_running` is not visited
176+ // `f_touch_task` is not visited
187177 // `f_join_running_task` is not visited
188178 // `f_next_task_id` is not visited
189179 }
@@ -204,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
204194 }
205195 }
206196
207- void SetTaskStopped (int task_id) final {
208- if (f_set_task_stopped == nullptr ) {
209- TaskSchedulerNode::SetTaskStopped (task_id);
210- } else {
211- f_set_task_stopped (task_id);
212- }
213- }
214-
215- bool IsTaskRunning (int task_id) final {
216- if (f_is_task_running == nullptr ) {
217- return TaskSchedulerNode::IsTaskRunning (task_id);
197+ void TouchTask (int task_id) final {
198+ if (f_touch_task == nullptr ) {
199+ return TaskSchedulerNode::TouchTask (task_id);
218200 } else {
219- return f_is_task_running (task_id);
201+ return f_touch_task (task_id);
220202 }
221203 }
222204
223- void JoinRunningTask (int task_id) final {
205+ Array<RunnerResult> JoinRunningTask (int task_id) final {
224206 if (f_join_running_task == nullptr ) {
225207 return TaskSchedulerNode::JoinRunningTask (task_id);
226208 } else {
@@ -249,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef {
249231 * \param builder The builder of the scheduler.
250232 * \param runner The runner of the scheduler.
251233 * \param database The database of the scheduler.
234+ * \param max_trials The maximum number of trials.
252235 * \param cost_model The cost model of the scheduler.
253236 * \param measure_callbacks The measure callbacks of the scheduler.
254237 * \return The task scheduler created.
@@ -257,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef {
257240 Builder builder, //
258241 Runner runner, //
259242 Database database, //
243+ int max_trials, //
260244 Optional<CostModel> cost_model, //
261245 Optional<Array<MeasureCallback>> measure_callbacks);
246+ /* !
247+ * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
248+ * \param tasks The tasks to be tuned.
249+ * \param task_weights The weights of each task.
250+ * \param builder The builder of the scheduler.
251+ * \param runner The runner of the scheduler.
252+ * \param database The database of the scheduler.
253+ * \param max_trials The maximum number of trials.
254+ * \param cost_model The cost model of the scheduler.
255+ * \param measure_callbacks The measure callbacks of the scheduler.
256+ * \param alpha The parameter alpha to control gradient computation.
257+ * \param window_size The parameter to control backward window size.
258+ * \param seed The random seed.
259+ * \return The task scheduler created.
260+ */
261+ TVM_DLL static TaskScheduler GradientBased (Array<TuneContext> tasks,
262+ Array<FloatImm> task_weights, //
263+ Builder builder, //
264+ Runner runner, //
265+ Database database, //
266+ int max_trials, //
267+ Optional<CostModel> cost_model, //
268+ Optional<Array<MeasureCallback>> measure_callbacks, //
269+ double alpha, //
270+ int window_size, //
271+ support::LinearCongruentialEngine::TRandState seed);
262272 /* !
263273 * \brief Create a task scheduler with customized methods on the python-side.
264274 * \param tasks The tasks to be tuned.
265275 * \param builder The builder of the scheduler.
266276 * \param runner The runner of the scheduler.
267277 * \param database The database of the scheduler.
278+ * \param max_trials The maximum number of trials.
268279 * \param cost_model The cost model of the scheduler.
269280 * \param measure_callbacks The measure callbacks of the scheduler.
270281 * \param f_tune The packed function of `Tune`.
271282 * \param f_initialize_task The packed function of `InitializeTask`.
272- * \param f_set_task_stopped The packed function of `SetTaskStopped`.
273- * \param f_is_task_running The packed function of `IsTaskRunning`.
283+ * \param f_touch_task The packed function of `TouchTask`.
274284 * \param f_join_running_task The packed function of `JoinRunningTask`.
275285 * \param f_next_task_id The packed function of `NextTaskId`.
276286 * \return The task scheduler created.
@@ -280,12 +290,12 @@ class TaskScheduler : public runtime::ObjectRef {
280290 Builder builder, //
281291 Runner runner, //
282292 Database database, //
293+ int max_trials, //
283294 Optional<CostModel> cost_model, //
284295 Optional<Array<MeasureCallback>> measure_callbacks, //
285296 PyTaskSchedulerNode::FTune f_tune, //
286297 PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
287- PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
288- PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
298+ PyTaskSchedulerNode::FTouchTask f_touch_task, //
289299 PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
290300 PyTaskSchedulerNode::FNextTaskId f_next_task_id);
291301 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS (TaskScheduler, ObjectRef, TaskSchedulerNode);
0 commit comments