From ff961b651e8352ae62e34c838a5982ee9f82e9b3 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 11:51:36 -0700 Subject: [PATCH 01/10] Add clone function to the schedule rule family. --- include/tvm/meta_schedule/schedule_rule.h | 17 +++++++++++++ .../schedule_rule/schedule_rule.py | 24 ++++++++++++++++++- .../schedule_rule/add_rfactor.cc | 10 ++++++++ src/meta_schedule/schedule_rule/auto_bind.cc | 9 +++++++ .../schedule_rule/auto_inline.cc | 13 ++++++++++ .../schedule_rule/cross_thread_reduction.cc | 9 +++++++ .../schedule_rule/multi_level_tiling.cc | 17 +++++++++++++ .../schedule_rule/multi_level_tiling.h | 3 +++ .../parallel_vectorize_unroll.cc | 12 ++++++++++ .../schedule_rule/random_compute_location.cc | 6 +++++ .../schedule_rule/schedule_rule.cc | 9 +++++++ 11 files changed, 128 insertions(+), 1 deletion(-) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 2da441c95e0b..1a0b2e773466 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -34,6 +34,7 @@ namespace tvm { namespace meta_schedule { class TuneContext; +class ScheduleRule; /*! \brief Rules to modify a block in a schedule. */ class ScheduleRuleNode : public runtime::Object { @@ -59,6 +60,12 @@ class ScheduleRuleNode : public runtime::Object { virtual runtime::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; + /*! + * \brief Deep clone the schedule rule. + * \return The cloned schedule rule. + */ + virtual ScheduleRule Clone() = 0; + static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); }; @@ -84,6 +91,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode { * \return The string of the schedule rule. */ using FAsString = runtime::TypedPackedFunc; + /*! + * \brief The function type of `Clone` method. + * \return The cloned schedule rule. + */ + using FClone = runtime::TypedPackedFunc; /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; @@ -91,15 +103,19 @@ class PyScheduleRuleNode : public ScheduleRuleNode { FApply f_apply; /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; void VisitAttrs(tvm::AttrVisitor* v) { // `f_initialize_with_tune_context` is not visited // `f_apply` is not visited // `f_as_string` is not visited + // `f_clone` is not visited } void InitializeWithTuneContext(const TuneContext& context) final; Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; + ScheduleRule Clone() final; static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); @@ -255,6 +271,7 @@ class ScheduleRule : public runtime::ObjectRef { TVM_DLL static ScheduleRule PyScheduleRule( PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FClone f_clone, // PyScheduleRuleNode::FAsString f_as_string); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); }; diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 481444341b86..665f99adb8de 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -66,6 +66,16 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: self, sch, block ) + def clone(self) -> "ScheduleRule": + """Deep clone the schedule rule. + + Returns + ------- + cloned_rule : ScheduleRule + The cloned schedule rule. + """ + return _ffi_api.ScheduleRuleClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PyScheduleRule") class _PyScheduleRule(ScheduleRule): @@ -80,6 +90,7 @@ def __init__( self, f_initialize_with_tune_context: Callable = None, f_apply: Callable = None, + f_clone: Callable = None, f_as_string: Callable = None, ): """Constructor.""" @@ -88,6 +99,7 @@ def __init__( _ffi_api.ScheduleRulePyScheduleRule, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_apply, + f_clone, f_as_string, ) @@ -102,7 +114,7 @@ class PyScheduleRule: _tvm_metadata = { "cls": _PyScheduleRule, - "methods": ["_initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -136,6 +148,16 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: self, sch, block ) + def clone(self) -> ScheduleRule: + """Deep clone the schedule rule. + + Returns + ------- + cloned_rule : ScheduleRule + The cloned schedule rule. + """ + return _ffi_api.ScheduleRuleClone(self) # type: ignore # pylint: disable=no-member + def __str__(self) -> str: """Get the schedule rule as string with name. diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index cf87f24ac233..0a63d42f65af 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -36,6 +36,16 @@ class AddRFactorNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + // Inherited from ScheduleRuleNode + ScheduleRule Clone() final { + ObjectPtr n = make_object(*this); + n->max_jobs_per_core = this->max_jobs_per_core; + n->max_innermost_factor = this->max_innermost_factor; + n->max_parallel_extent_ = this->max_parallel_extent_; + n->max_parallel_basic_ = this->max_parallel_basic_; + return ScheduleRule(n); + } + public: /*! * \brief The maximum number of jobs to be launched per core. diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index d8f52fa8e1de..b2c85e4d9ad3 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -177,6 +177,15 @@ class AutoBindNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + // Inherited from ScheduleRuleNode + ScheduleRule Clone() final { + ObjectPtr n = make_object(*this); + n->max_threads_per_block_ = this->max_threads_per_block_; + n->max_threadblocks_ = this->max_threadblocks_; + n->thread_extents_ = this->thread_extents_; + return ScheduleRule(n); + } + public: /*! \brief The max number of threads per block from Target */ int64_t max_threads_per_block_ = -1; diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 446c8ead7e8e..44f8d4783f0a 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -60,6 +60,19 @@ class AutoInlineNode : public ScheduleRuleNode { return {sch}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() final { + ObjectPtr n = make_object(*this); + n->into_producer = into_producer; + n->into_consumer = into_consumer; + n->inline_const_tensor = inline_const_tensor; + n->disallow_if_then_else = disallow_if_then_else; + n->require_injective = require_injective; + n->require_ordered = require_ordered; + n->disallow_op = disallow_op; + return ScheduleRule(n); + } + public: /*! \brief If allows to inline a block into its producer */ bool into_producer; diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 35be33f72e21..1551cd019d71 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -113,6 +113,15 @@ class CrossThreadReductionNode : public ScheduleRuleNode { return {tmp_sch, sch}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() final { + ObjectPtr n = make_object(*this); + n->thread_extents = thread_extents; + n->max_threads_per_block = max_threads_per_block; + n->warp_size = warp_size; + return ScheduleRule(n); + } + private: /*! * \brief Check whether the input block is in thread scope, i.e., some of its outer loop is diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index c126c854462c..ab0024b84e1a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -104,6 +104,23 @@ Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& return results; } +// Inherited from ScheduleRuleNode +ScheduleRule MultiLevelTilingNode::Clone() { + ObjectPtr n = make_object(*this); + n->structure = this->structure; + n->tile_binds = this->tile_binds; + n->max_innermost_factor = this->max_innermost_factor; + n->vector_load_lens = this->vector_load_lens; + n->reuse_read_ = this->reuse_read_; + n->reuse_write_ = this->reuse_write_; + n->s_indices_ = this->s_indices_; + n->r_indices_ = this->r_indices_; + n->thread_warp_size_ = this->thread_warp_size_; + n->max_threads_per_block_ = this->max_threads_per_block_; + n->logging_func = this->logging_func; + return ScheduleRule(n); +} + std::vector MultiLevelTilingNode::ApplySubRules(std::vector states) { states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); }); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 9161a972c187..dc8d6de5d274 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -155,6 +155,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode { // Entry of the mega rule; Inherited from ScheduleRuleNode Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; + // Inherited from ScheduleRuleNode + ScheduleRule Clone() override; + protected: virtual std::vector ApplySubRules(std::vector states); diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 19758996e608..a87aed578882 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -79,6 +79,18 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { return {sch}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() final { + ObjectPtr n = + make_object(*this); + n->max_jobs_per_core = this->max_jobs_per_core; + n->max_vectorize_extent = this->max_vectorize_extent; + n->unroll_max_steps = this->unroll_max_steps; + n->unroll_explicit = this->unroll_explicit; + n->max_parallel_extent_ = this->max_parallel_extent_; + return ScheduleRule(n); + } + public: /*! * \brief The maximum number of jobs to be launched per CPU core. It sets the diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 65988dfd5688..5774ce0a1e19 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -57,6 +57,12 @@ class RandomComputeLocationNode : public ScheduleRuleNode { return {res}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + private: bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { tir::StmtSRef block_sref = sch->GetSRef(block_rv); diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 80f8725b0c0d..7a452f68482f 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -33,13 +33,20 @@ Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, return f_apply(sch, block); } +ScheduleRule PyScheduleRuleNode::Clone() { + ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!"; + return f_clone(); +} + ScheduleRule ScheduleRule::PyScheduleRule( PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FClone f_clone, // PyScheduleRuleNode::FAsString f_as_string) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); + n->f_clone = std::move(f_clone); n->f_as_string = std::move(f_as_string); return ScheduleRule(n); } @@ -60,6 +67,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") .set_body_method(&ScheduleRuleNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone") + .set_body_method(&ScheduleRuleNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") .set_body_typed(ScheduleRule::PyScheduleRule); From f6cec00e962cb9537a15e5d4d8e5e4583cdaf690 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 12:13:30 -0700 Subject: [PATCH 02/10] Fix linting. --- include/tvm/meta_schedule/schedule_rule.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 1a0b2e773466..ac1cff6b6089 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -265,6 +265,7 @@ class ScheduleRule : public runtime::ObjectRef { * \brief Create a schedule rule with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_clone The packed function of `Clone`. * \param f_as_string The packed function of `AsString`. * \return The schedule rule created. */ From 99408cb7dc608525598d5864ea3c3733afbb8227 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 15:25:15 -0700 Subject: [PATCH 03/10] Fix incomplete class. --- include/tvm/meta_schedule/schedule_rule.h | 78 ++++++++++++----------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index ac1cff6b6089..9d04d7cdb126 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -70,8 +70,11 @@ class ScheduleRuleNode : public runtime::Object { TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); }; -/*! \brief The schedule rule with customized methods on the python-side. */ -class PyScheduleRuleNode : public ScheduleRuleNode { +/*! + * \brief Managed reference to ScheduleRuleNode + * \sa ScheduleRuleNode + */ +class ScheduleRule : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -96,37 +99,6 @@ class PyScheduleRuleNode : public ScheduleRuleNode { * \return The cloned schedule rule. */ using FClone = runtime::TypedPackedFunc; - - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` function. */ - FApply f_apply; - /*! \brief The packed function to the `AsString` function. */ - FAsString f_as_string; - /*! \brief The packed function to the `Clone` function. */ - FClone f_clone; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_as_string` is not visited - // `f_clone` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; - ScheduleRule Clone() final; - - static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); -}; - -/*! - * \brief Managed reference to ScheduleRuleNode - * \sa ScheduleRuleNode - */ -class ScheduleRule : public runtime::ObjectRef { - public: /*! * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions * \param into_producer If allows to inline a block into its producer @@ -270,13 +242,45 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created. */ TVM_DLL static ScheduleRule PyScheduleRule( - PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PyScheduleRuleNode::FApply f_apply, // - PyScheduleRuleNode::FClone f_clone, // - PyScheduleRuleNode::FAsString f_as_string); + FInitializeWithTuneContext f_initialize_with_tune_context, // + FApply f_apply, // + FClone f_clone, // + FAsString f_as_string); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); }; +/*! \brief The schedule rule with customized methods on the python-side. */ +class PyScheduleRuleNode : public ScheduleRuleNode { + public: + using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext; + using FApply = ScheduleRule::FApply; + using FClone = ScheduleRule::FClone; + using FAsString = ScheduleRule::FAsString; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + // `f_clone` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; + ScheduleRule Clone() final; + + static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); +}; + } // namespace meta_schedule } // namespace tvm From c967b026bfd132ad575555826e08687f464289b0 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 15:53:26 -0700 Subject: [PATCH 04/10] Add clone function for the postproc family. --- include/tvm/meta_schedule/postproc.h | 86 ++++++++++++------- python/tvm/meta_schedule/postproc/postproc.py | 24 +++++- .../postproc/disallow_dynamic_loop.cc | 5 ++ src/meta_schedule/postproc/postproc.cc | 8 ++ .../postproc/rewrite_cooperative_fetch.cc | 5 ++ src/meta_schedule/postproc/rewrite_layout.cc | 5 ++ .../rewrite_parallel_vectorize_unroll.cc | 6 ++ .../postproc/rewrite_reduction_block.cc | 5 ++ .../postproc/rewrite_tensorize.cc | 5 ++ .../postproc/rewrite_unbound_block.cc | 5 ++ src/meta_schedule/postproc/verify_gpu_code.cc | 6 ++ 11 files changed, 126 insertions(+), 34 deletions(-) diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 5d99f6845463..4fafb9557631 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -29,6 +29,7 @@ namespace tvm { namespace meta_schedule { class TuneContext; +class Postproc; /*! * \brief Rules to apply a postprocessor to a schedule. @@ -54,12 +55,21 @@ class PostprocNode : public runtime::Object { */ virtual bool Apply(const tir::Schedule& sch) = 0; + /*! + * \brief Clone the postprocessor. + * \return The cloned postprocessor. + */ + virtual Postproc Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.Postproc"; TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); }; -/*! \brief The postprocessor with customized methods on the python-side. */ -class PyPostprocNode : public PostprocNode { +/*! + * \brief Managed reference to PostprocNode + * \sa PostprocNode + */ +class Postproc : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -72,49 +82,28 @@ class PyPostprocNode : public PostprocNode { * \return Whether the postprocessor was successfully applied. */ using FApply = runtime::TypedPackedFunc; + /*! + * \brief Clone the postprocessor. + * \return The cloned postprocessor. + */ + using FClone = runtime::TypedPackedFunc; /*! * \brief Get the postprocessor function as string with name. * \return The string of the postprocessor function. */ using FAsString = runtime::TypedPackedFunc; - - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` function. */ - FApply f_apply; - /*! \brief The packed function to the `AsString` function. */ - FAsString f_as_string; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_as_string` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - bool Apply(const tir::Schedule& sch) final; - - static constexpr const char* _type_key = "meta_schedule.PyPostproc"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); -}; - -/*! - * \brief Managed reference to PostprocNode - * \sa PostprocNode - */ -class Postproc : public runtime::ObjectRef { - public: /*! * \brief Create a postprocessor with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_clone The packed function of `Clone`. * \param f_as_string The packed function of `AsString`. * \return The postprocessor created. */ - TVM_DLL static Postproc PyPostproc( - PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PyPostprocNode::FApply f_apply, // - PyPostprocNode::FAsString f_as_string); + TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, // + FApply f_apply, // + FClone f_clone, // + FAsString f_as_string); /*! * \brief Create a postprocessor that checks if all loops are static * \return The postprocessor created @@ -164,6 +153,37 @@ class Postproc : public runtime::ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); }; +/*! \brief The postprocessor with customized methods on the python-side. */ +class PyPostprocNode : public PostprocNode { + public: + using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext; + using FApply = Postproc::FApply; + using FClone = Postproc::FClone; + using FAsString = Postproc::FAsString; + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_clone` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PyPostproc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index e37666bd1ce0..6eec2965ceeb 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -60,6 +60,16 @@ def apply(self, sch: Schedule) -> bool: """ return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member + def clone(self) -> "Postproc": + """Clone the postprocessor. + + Returns + ------- + cloned_postproc : Postproc + The cloned postprocessor. + """ + return _ffi_api.PostprocClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PyPostproc") class _PyPostproc(Postproc): @@ -74,6 +84,7 @@ def __init__( self, f_initialize_with_tune_context: Callable = None, f_apply: Callable = None, + f_clone: Callable = None, f_as_string: Callable = None, ): """Constructor.""" @@ -82,6 +93,7 @@ def __init__( _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_apply, + f_clone, f_as_string, ) @@ -96,7 +108,7 @@ class PyPostproc: _tvm_metadata = { "cls": _PyPostproc, - "methods": ["_initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -124,6 +136,16 @@ def apply(self, sch: Schedule) -> bool: """ raise NotImplementedError + def clone(self) -> Postproc: + """Clone the postprocessor. + + Returns + ------- + cloned_postproc : Postproc + The cloned postprocessor. + """ + raise NotImplementedError + def __str__(self) -> str: """Get the post processor as string with name. diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index 85a81f10fdcd..8362da552ea5 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -67,6 +67,11 @@ class DisallowDynamicLoopNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } + // Inherited from PostprocNode + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index 0f4f1b1192f6..957d6e7364e4 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -32,13 +32,20 @@ bool PyPostprocNode::Apply(const tir::Schedule& sch) { return f_apply(sch); } +Postproc PyPostprocNode::Clone() const { + ICHECK(f_clone != nullptr) << "PyPostproc's Clone method not implemented!"; + return f_clone(); +} + Postproc Postproc::PyPostproc( PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyPostprocNode::FApply f_apply, // + PyPostprocNode::FClone f_clone, // PyPostprocNode::FAsString f_as_string) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); + n->f_clone = std::move(f_clone); n->f_as_string = std::move(f_as_string); return Postproc(n); } @@ -58,6 +65,7 @@ TVM_REGISTER_NODE_TYPE(PyPostprocNode); TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") .set_body_method(&PostprocNode::InitializeWithTuneContext); TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index d111bdb42abb..ac9f45ca8ef4 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -104,6 +104,11 @@ class RewriteCooperativeFetchNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index f4cbdfe737fb..6ff9958c791f 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -167,6 +167,11 @@ class RewriteLayoutNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); } + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + static constexpr const char* _type_key = "meta_schedule.RewriteLayout"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode); }; diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 08d25d017840..c3cc0ef60152 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -384,6 +384,12 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { return true; } + Postproc Clone() const { + ObjectPtr n = + make_object(*this); + return Postproc(n); + } + static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); }; diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index ea204e306133..05a7640f047c 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -114,6 +114,11 @@ class RewriteReductionBlockNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 3b6c438d0216..4f8e0fb213f8 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -68,6 +68,11 @@ class RewriteTensorizeNode : public PostprocNode { void VisitAttrs(tvm::AttrVisitor* v) {} + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + bool vectorize_init_loop = false; static constexpr const char* _type_key = "meta_schedule.RewriteTensorize"; diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index eb57e90f82f6..1ba68538ea04 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -97,6 +97,11 @@ class RewriteUnboundBlockNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + public: /*! \brief The max number of threads per block from Target */ int max_threads_per_block_ = -1; diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index dfe2c5a06a17..0828ee538427 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -196,6 +196,12 @@ class VerifyGPUCodeNode : public PostprocNode { return true; } + Postproc Clone() const { + ObjectPtr n = make_object(*this); + n->target_constraints_ = this->target_constraints_; + return Postproc(n); + } + static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode"; TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode); }; From a33425edf842bd947e5b82bab29a9f4a34708e2d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 16:07:45 -0700 Subject: [PATCH 05/10] Change clone function to const and remove extra for schedule rules. --- include/tvm/meta_schedule/schedule_rule.h | 4 ++-- src/meta_schedule/schedule_rule/add_rfactor.cc | 6 +----- src/meta_schedule/schedule_rule/auto_bind.cc | 5 +---- src/meta_schedule/schedule_rule/auto_inline.cc | 9 +-------- .../schedule_rule/cross_thread_reduction.cc | 5 +---- .../schedule_rule/multi_level_tiling.cc | 13 +------------ .../schedule_rule/multi_level_tiling.h | 2 +- .../schedule_rule/multi_level_tiling_tensor_core.cc | 7 +++++++ .../schedule_rule/multi_level_tiling_with_intrin.cc | 7 +++++++ .../schedule_rule/parallel_vectorize_unroll.cc | 7 +------ .../schedule_rule/random_compute_location.cc | 2 +- src/meta_schedule/schedule_rule/schedule_rule.cc | 2 +- 12 files changed, 25 insertions(+), 44 deletions(-) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 9d04d7cdb126..55704cf4a97d 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -64,7 +64,7 @@ class ScheduleRuleNode : public runtime::Object { * \brief Deep clone the schedule rule. * \return The cloned schedule rule. */ - virtual ScheduleRule Clone() = 0; + virtual ScheduleRule Clone() const = 0; static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); @@ -275,7 +275,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final; Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; - ScheduleRule Clone() final; + ScheduleRule Clone() const final; static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 0a63d42f65af..2fc1352677cb 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -37,12 +37,8 @@ class AddRFactorNode : public ScheduleRuleNode { Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); // Inherited from ScheduleRuleNode - ScheduleRule Clone() final { + ScheduleRule Clone() const final { ObjectPtr n = make_object(*this); - n->max_jobs_per_core = this->max_jobs_per_core; - n->max_innermost_factor = this->max_innermost_factor; - n->max_parallel_extent_ = this->max_parallel_extent_; - n->max_parallel_basic_ = this->max_parallel_basic_; return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index b2c85e4d9ad3..7af1418d8f3e 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -178,11 +178,8 @@ class AutoBindNode : public ScheduleRuleNode { Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; // Inherited from ScheduleRuleNode - ScheduleRule Clone() final { + ScheduleRule Clone() const final { ObjectPtr n = make_object(*this); - n->max_threads_per_block_ = this->max_threads_per_block_; - n->max_threadblocks_ = this->max_threadblocks_; - n->thread_extents_ = this->thread_extents_; return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 44f8d4783f0a..dcdc83f95cb1 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -61,15 +61,8 @@ class AutoInlineNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ScheduleRule Clone() final { + ScheduleRule Clone() const final { ObjectPtr n = make_object(*this); - n->into_producer = into_producer; - n->into_consumer = into_consumer; - n->inline_const_tensor = inline_const_tensor; - n->disallow_if_then_else = disallow_if_then_else; - n->require_injective = require_injective; - n->require_ordered = require_ordered; - n->disallow_op = disallow_op; return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 1551cd019d71..f2fc67f74cc7 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -114,11 +114,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ScheduleRule Clone() final { + ScheduleRule Clone() const final { ObjectPtr n = make_object(*this); - n->thread_extents = thread_extents; - n->max_threads_per_block = max_threads_per_block; - n->warp_size = warp_size; return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index ab0024b84e1a..1625a27b9aaf 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -105,19 +105,8 @@ Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& } // Inherited from ScheduleRuleNode -ScheduleRule MultiLevelTilingNode::Clone() { +ScheduleRule MultiLevelTilingNode::Clone() const { ObjectPtr n = make_object(*this); - n->structure = this->structure; - n->tile_binds = this->tile_binds; - n->max_innermost_factor = this->max_innermost_factor; - n->vector_load_lens = this->vector_load_lens; - n->reuse_read_ = this->reuse_read_; - n->reuse_write_ = this->reuse_write_; - n->s_indices_ = this->s_indices_; - n->r_indices_ = this->r_indices_; - n->thread_warp_size_ = this->thread_warp_size_; - n->max_threads_per_block_ = this->max_threads_per_block_; - n->logging_func = this->logging_func; return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index dc8d6de5d274..47da878c3be0 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -156,7 +156,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; // Inherited from ScheduleRuleNode - ScheduleRule Clone() override; + ScheduleRule Clone() const override; protected: virtual std::vector ApplySubRules(std::vector states); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 7ddda9b2635b..13b00fa7deb6 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -137,6 +137,13 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { // Override Apply to apply tensorization-specific analysis before applying sub-rules Array Apply(const Schedule& sch, const BlockRV& block_rv) final; + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = + make_object(*this); + return ScheduleRule(n); + } + /*! * \brief Transform and tensorize with the given tensor intrin * \param state The state of the meta schedule rule diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 3a299ed041e2..b953d1ad4b50 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -63,6 +63,13 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { return res; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = + make_object(*this); + return ScheduleRule(n); + } + // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then // tile the outerloops. virtual std::vector ApplySubRules(std::vector states) { diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index a87aed578882..045aa85b73ad 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -80,14 +80,9 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ScheduleRule Clone() final { + ScheduleRule Clone() const final { ObjectPtr n = make_object(*this); - n->max_jobs_per_core = this->max_jobs_per_core; - n->max_vectorize_extent = this->max_vectorize_extent; - n->unroll_max_steps = this->unroll_max_steps; - n->unroll_explicit = this->unroll_explicit; - n->max_parallel_extent_ = this->max_parallel_extent_; return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 5774ce0a1e19..7796eddd44d3 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -58,7 +58,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ScheduleRule Clone() final { + ScheduleRule Clone() const final { ObjectPtr n = make_object(*this); return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 7a452f68482f..416b43f46d56 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -33,7 +33,7 @@ Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, return f_apply(sch, block); } -ScheduleRule PyScheduleRuleNode::Clone() { +ScheduleRule PyScheduleRuleNode::Clone() const { ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!"; return f_clone(); } From da3e421dddc96ea430aeba067e534fbdd9646199 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 16:20:47 -0700 Subject: [PATCH 06/10] Add clone func for the mutator family. --- include/tvm/meta_schedule/mutator.h | 88 ++++++++++++------- python/tvm/meta_schedule/mutator/mutator.py | 24 ++++- .../mutator/mutate_compute_location.cc | 5 ++ src/meta_schedule/mutator/mutate_parallel.cc | 5 ++ .../mutator/mutate_thread_binding.cc | 5 ++ src/meta_schedule/mutator/mutate_tile_size.cc | 5 ++ src/meta_schedule/mutator/mutate_unroll.cc | 5 ++ src/meta_schedule/mutator/mutator.cc | 8 ++ 8 files changed, 110 insertions(+), 35 deletions(-) diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 566cc82e9716..2b580e75e019 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -32,6 +32,7 @@ namespace tvm { namespace meta_schedule { class TuneContext; +class Mutator; /*! \brief Mutator is designed to mutate the trace to explore the design space. */ class MutatorNode : public runtime::Object { @@ -57,12 +58,21 @@ class MutatorNode : public runtime::Object { virtual Optional Apply(const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0; + /*! + * \brief Clone the mutator. + * \return The cloned mutator. + */ + virtual Mutator Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.Mutator"; TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); }; -/*! \brief The mutator with customized methods on the python-side. */ -class PyMutatorNode : public MutatorNode { +/*! + * \brief Managed reference to MutatorNode + * \sa MutatorNode + */ +class Mutator : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -76,39 +86,16 @@ class PyMutatorNode : public MutatorNode { */ using FApply = runtime::TypedPackedFunc( const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; + /*! + * \brief Clone the mutator. + * \return The cloned mutator. + */ + using FClone = runtime::TypedPackedFunc; /*! * \brief Get the mutator as string with name. * \return The string of the mutator. */ using FAsString = runtime::TypedPackedFunc; - - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` function. */ - FApply f_apply; - /*! \brief The packed function to the `AsString` function. */ - FAsString f_as_string; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_as_string` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - Optional Apply(const tir::Trace& trace, - support::LinearCongruentialEngine::TRandState* rand_state) final; - - static constexpr const char* _type_key = "meta_schedule.PyMutator"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); -}; - -/*! - * \brief Managed reference to MutatorNode - * \sa MutatorNode - */ -class Mutator : public runtime::ObjectRef { - public: /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */ TVM_DLL static Mutator MutateTileSize(); /*! @@ -136,16 +123,49 @@ class Mutator : public runtime::ObjectRef { * \brief Create a mutator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_clone The packed function of `Clone`. * \param f_as_string The packed function of `AsString`. * \return The mutator created. */ - TVM_DLL static Mutator PyMutator( - PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PyMutatorNode::FApply f_apply, // - PyMutatorNode::FAsString f_as_string); + TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, // + FApply f_apply, // + FClone f_clone, // + FAsString f_as_string); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); }; +/*! \brief The mutator with customized methods on the python-side. */ +class PyMutatorNode : public MutatorNode { + public: + using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext; + using FApply = Mutator::FApply; + using FClone = Mutator::FClone; + using FAsString = Mutator::FAsString; + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_clone` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) final; + Mutator Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PyMutator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 0c8de9668034..c5286aced7d8 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -58,6 +58,16 @@ def apply(self, trace: Trace) -> Optional[Trace]: """ return _ffi_api.MutatorApply(self, trace, -1) # type: ignore # pylint: disable=no-member + def clone(self) -> "Mutator": + """Clone the mutator. + + Returns + ------- + mutator : Mutator + The cloned mutator. + """ + return _ffi_api.MutatorClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PyMutator") class _PyMutator(Mutator): @@ -72,6 +82,7 @@ def __init__( self, f_initialize_with_tune_context: Callable = None, f_apply: Callable = None, + f_clone: Callable = None, f_as_string: Callable = None, ): """Constructor.""" @@ -80,6 +91,7 @@ def __init__( _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_apply, + f_clone, f_as_string, ) @@ -94,7 +106,7 @@ class PyMutator: _tvm_metadata = { "cls": _PyMutator, - "methods": ["_initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -122,6 +134,16 @@ def apply(self, trace: Trace, _) -> Optional[Trace]: """ raise NotImplementedError + def clone(self) -> Mutator: + """Clone the mutator. + + Returns + ------- + mutator : Mutator + The cloned mutator. + """ + raise NotImplementedError + def __str__(self) -> str: """Get the mutator as string with name. diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 9d6d69ba355f..2a31d2da9b53 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -42,6 +42,11 @@ class MutateComputeLocationNode : public MutatorNode { } // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } private: struct Candidate { diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 82b91da682c6..9feb4747d807 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -188,6 +188,11 @@ class MutateParallelNode : public MutatorNode { } // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } }; /*! \brief The candidate to be mutated */ diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index de780b53e2d9..f5d89a85092b 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -42,6 +42,11 @@ class MutateThreadBindingNode : public MutatorNode { } // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } private: struct Candidate { diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 4a3bfda8a4a8..8fb83147ea7b 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -63,6 +63,11 @@ class MutateTileSizeNode : public MutatorNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } }; /*! diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index c282a171c3b7..7bbf00343af3 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -60,6 +60,11 @@ class MutateUnrollNode : public MutatorNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } }; /*! \brief A candidate to be mutated */ diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 43b95000c71d..25312ab61f99 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -33,13 +33,20 @@ Optional PyMutatorNode::Apply( return f_apply(trace, *rand_state); } +Mutator PyMutatorNode::Clone() const { + ICHECK(f_clone != nullptr) << "PyMutator's Clone method not implemented!"; + return f_clone(); +} + Mutator Mutator::PyMutator( PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyMutatorNode::FApply f_apply, // + PyMutatorNode::FClone f_clone, // PyMutatorNode::FAsString f_as_string) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); + n->f_clone = std::move(f_clone); n->f_as_string = std::move(f_as_string); return Mutator(n); } @@ -63,6 +70,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); return self->Apply(trace, &seed_); }); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); } // namespace meta_schedule From 80d150823b7e8ae59e6399c9b45981f02a7bd10c Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 16:23:57 -0700 Subject: [PATCH 07/10] Fix PyScheduleRule default func to not implemented. --- .../tvm/meta_schedule/schedule_rule/schedule_rule.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 665f99adb8de..2c8e223611aa 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -125,9 +125,7 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: context : TuneContext The tuning context for initializing the schedule rule. """ - _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member - self, context - ) + raise NotImplementedError def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: """Apply a schedule rule to the specific block in the given schedule. @@ -144,9 +142,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: design_spaces : List[Schedule] The list of schedules generated by applying the schedule rule. """ - return _ffi_api.ScheduleRuleApply( # type: ignore # pylint: disable=no-member - self, sch, block - ) + raise NotImplementedError def clone(self) -> ScheduleRule: """Deep clone the schedule rule. @@ -156,7 +152,7 @@ def clone(self) -> ScheduleRule: cloned_rule : ScheduleRule The cloned schedule rule. """ - return _ffi_api.ScheduleRuleClone(self) # type: ignore # pylint: disable=no-member + raise NotImplementedError def __str__(self) -> str: """Get the schedule rule as string with name. From b2d5a9b717cf725f152a14ba6b5dd0aff7b6414c Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 16:41:43 -0700 Subject: [PATCH 08/10] Add clone func for the space generator family. --- include/tvm/meta_schedule/space_generator.h | 78 ++++++++++++------- .../space_generator/space_generator.py | 24 +++++- .../space_generator/post_order_apply.cc | 9 +++ .../space_generator/schedule_fn.cc | 5 ++ .../space_generator/space_generator.cc | 12 ++- .../space_generator/space_generator_union.cc | 8 ++ 6 files changed, 104 insertions(+), 32 deletions(-) diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 2c1b2d4e4d7d..1e29e757a15c 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -31,6 +31,7 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class SpaceGenerator; /*! * \brief The abstract class for design space generation. @@ -87,12 +88,21 @@ class SpaceGeneratorNode : public runtime::Object { */ virtual Array GenerateDesignSpace(const IRModule& mod) = 0; + /*! + * \brief Clone the space generator. + * \return The cloned space generator. + */ + virtual SpaceGenerator Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.SpaceGenerator"; TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object); }; -/*! \brief The design space generator with customized methods on the python-side. */ -class PySpaceGeneratorNode : public SpaceGeneratorNode { +/*! + * \brief Managed reference to SpaceGeneratorNode. + * \sa SpaceGeneratorNode + */ +class SpaceGenerator : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -105,29 +115,12 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { * \return The generated design spaces, i.e., schedules. */ using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; + /*! + * \brief The function type of `Clone` method. + * \return The cloned space generator. + */ + using FClone = runtime::TypedPackedFunc; - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `GenerateDesignSpace` function. */ - FGenerateDesignSpace f_generate_design_space; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_generate_design_space` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - Array GenerateDesignSpace(const IRModule& mod) final; - - static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; - TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); -}; - -/*! - * \brief Managed reference to SpaceGeneratorNode. - * \sa SpaceGeneratorNode - */ -class SpaceGenerator : public runtime::ObjectRef { protected: SpaceGenerator() = default; @@ -136,11 +129,12 @@ class SpaceGenerator : public runtime::ObjectRef { * \brief Create a design space generator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_generate_design_space The packed function of `GenerateDesignSpace`. + * \param f_clone The packed function of `Clone`. * \return The design space generator created. */ TVM_DLL static SpaceGenerator PySpaceGenerator( - PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, - PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space); + FInitializeWithTuneContext f_initialize_with_tune_context, + FGenerateDesignSpace f_generate_design_space, FClone f_clone); /*! * \brief Create a design space generator with customized schedule function. * \param schedule_fn The schedule function, which can have the following signatures: @@ -156,14 +150,40 @@ class SpaceGenerator : public runtime::ObjectRef { */ TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); /*! - * \brief Create a design space generator that generates design spaces by applying schedule rules - * to blocks in post-DFS order. - * \return The design space generator created. + * \brief Create a design space generator that generates design spaces by applying schedule + * rules to blocks in post-DFS order. \return The design space generator created. */ TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; +/*! \brief The design space generator with customized methods on the python-side. */ +class PySpaceGeneratorNode : public SpaceGeneratorNode { + public: + using FInitializeWithTuneContext = SpaceGenerator::FInitializeWithTuneContext; + using FGenerateDesignSpace = SpaceGenerator::FGenerateDesignSpace; + using FClone = SpaceGenerator::FClone; + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `GenerateDesignSpace` function. */ + FGenerateDesignSpace f_generate_design_space; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_generate_design_space` is not visited + // `f_clone` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + Array GenerateDesignSpace(const IRModule& mod) final; + SpaceGenerator Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 9d7ebf3bae26..23c0361645b5 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -72,6 +72,16 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]: """ return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member + def clone(self) -> "SpaceGenerator": + """Clone the design space generator. + + Returns + ------- + cloned_sg : SpaceGenerator + The cloned design space generator. + """ + return _ffi_api.SpaceGeneratorClone(self) # type: ignore # pylint: disable=no-member + ScheduleFnType = SpaceGenerator.ScheduleFnType @@ -89,6 +99,7 @@ def __init__( self, f_initialize_with_tune_context: Optional[Callable] = None, f_generate_design_space: Optional[Callable] = None, + f_clone: Optional[Callable] = None, ): """Constructor.""" @@ -96,6 +107,7 @@ def __init__( _ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_generate_design_space, + f_clone, ) @@ -109,7 +121,7 @@ class PySpaceGenerator: _tvm_metadata = { "cls": _PySpaceGenerator, - "methods": ["_initialize_with_tune_context", "generate_design_space"], + "methods": ["_initialize_with_tune_context", "generate_design_space", "clone"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -137,6 +149,16 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]: """ raise NotImplementedError + def clone(self) -> SpaceGenerator: + """Clone the design space generator. + + Returns + ------- + cloned_sg : SpaceGenerator + The cloned design space generator. + """ + raise NotImplementedError + def create( # pylint: disable=keyword-arg-before-vararg kind: Union[ diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 9be89e2d9c70..991e4fa08047 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -188,6 +188,15 @@ class PostOrderApplyNode : public SpaceGeneratorNode { } return result; } + + SpaceGenerator Clone() const final { + ObjectPtr n = make_object(*this); + n->sch_rules_ = Array(); + for (const ScheduleRule& sch_rule : this->sch_rules_) { + n->sch_rules_.push_back(sch_rule->Clone()); + } + return SpaceGenerator(n); + } static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 70559fbcf1fb..adea139b1cd4 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -72,6 +72,11 @@ class ScheduleFnNode : public SpaceGeneratorNode { throw; } + SpaceGenerator Clone() const final { + ObjectPtr n = make_object(*this); + return SpaceGenerator(n); + } + static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); }; diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 5c5ab6ebbae5..6fc31ed896f2 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -33,12 +33,18 @@ Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& m return f_generate_design_space(mod); } +SpaceGenerator PySpaceGeneratorNode::Clone() const { + ICHECK(f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!"; + return f_clone(); +} + SpaceGenerator SpaceGenerator::PySpaceGenerator( - PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, - PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space) { + FInitializeWithTuneContext f_initialize_with_tune_context, + FGenerateDesignSpace f_generate_design_space, FClone f_clone) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_generate_design_space = std::move(f_generate_design_space); + n->f_clone = std::move(f_clone); return SpaceGenerator(n); } @@ -51,6 +57,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") .set_body_method(&SpaceGeneratorNode::GenerateDesignSpace); TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") .set_body_typed(SpaceGenerator::PySpaceGenerator); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone") + .set_body_method(&SpaceGeneratorNode::Clone); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 6ea61824f932..27e4beecde14 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -47,6 +47,14 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { return design_spaces; } + SpaceGenerator Clone() const final { + ObjectPtr n = make_object(*this); + n->space_generators = Array(); + for (const SpaceGenerator& space_generator : this->space_generators) { + n->space_generators.push_back(space_generator->Clone()); + } + } + static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion"; TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode); }; From a921e3640bb659e09d3411c97f5ac085fa350e7b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 17:04:01 -0700 Subject: [PATCH 09/10] Add clone func for search strategy family. --- include/tvm/meta_schedule/search_strategy.h | 114 +++++++++++------- .../search_strategy/search_strategy.py | 23 ++++ .../search_strategy/evolutionary_search.cc | 18 +++ .../search_strategy/replay_func.cc | 10 ++ .../search_strategy/replay_trace.cc | 11 ++ .../search_strategy/search_strategy.cc | 11 +- 6 files changed, 141 insertions(+), 46 deletions(-) diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index a75a4cd8ae86..efd3dc24524a 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -36,6 +36,7 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class SearchStrategy; /*! * \brief The search strategy for measure candidates generation. @@ -119,12 +120,21 @@ class SearchStrategyNode : public runtime::Object { virtual void NotifyRunnerResults(const Array& measure_candidates, const Array& results) = 0; + /*! + * \brief Clone the search strategy. + * \return The cloned search strategy. + */ + virtual SearchStrategy Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); }; -/*! \brief The python side customizable class for measure candidate generation */ -class PySearchStrategyNode : public SearchStrategyNode { +/*! + * \brief Managed reference to SearchStrategyNode. + * \sa SearchStrategyNode + */ +class SearchStrategy : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -150,44 +160,11 @@ class PySearchStrategyNode : public SearchStrategyNode { */ using FNotifyRunnerResults = runtime::TypedPackedFunc&, const Array&)>; - - /*! \brief The packed function to the `InitializeWithTuneContext` method. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `PreTuning` method. */ - FPreTuning f_pre_tuning; - /*! \brief The packed function to the `PostTuning` method. */ - FPostTuning f_post_tuning; - /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ - FGenerateMeasureCandidates f_generate_measure_candidates; - /*! \brief The packed function to the `NotifyRunnerResults` method. */ - FNotifyRunnerResults f_notify_runner_results; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_pre_tuning` is not visited - // `f_post_tuning` is not visited - // `f_generate_measure_candidates` is not visited - // `f_notify_runner_results` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - void PreTuning(const Array& design_spaces, const Optional& database, - const Optional& cost_model) final; - void PostTuning() final; - Optional> GenerateMeasureCandidates() final; - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results); - - static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; - TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); -}; - -/*! - * \brief Managed reference to SearchStrategyNode. - * \sa SearchStrategyNode - */ -class SearchStrategy : public runtime::ObjectRef { - public: + /*! + * \brief The function type of `Clone` method. + * \return The cloned search strategy. + */ + using FClone = runtime::TypedPackedFunc; /*! * \brief Create a search strategy with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. @@ -195,14 +172,16 @@ class SearchStrategy : public runtime::ObjectRef { * \param f_post_tuning The packed function of `PostTuning`. * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. + * \param f_clone The packed function of `Clone`. * \return The search strategy created. */ TVM_DLL static SearchStrategy PySearchStrategy( - PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PySearchStrategyNode::FPreTuning f_pre_tuning, // - PySearchStrategyNode::FPostTuning f_post_tuning, // - PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // - PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); + FInitializeWithTuneContext f_initialize_with_tune_context, // + FPreTuning f_pre_tuning, // + FPostTuning f_post_tuning, // + FGenerateMeasureCandidates f_generate_measure_candidates, // + FNotifyRunnerResults f_notify_runner_results, // + FClone f_clone); /*! * \brief Constructor of replay trace search strategy. @@ -245,6 +224,51 @@ class SearchStrategy : public runtime::ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; +/*! \brief The python side customizable class for measure candidate generation */ +class PySearchStrategyNode : public SearchStrategyNode { + public: + using FInitializeWithTuneContext = SearchStrategy::FInitializeWithTuneContext; + using FPreTuning = SearchStrategy::FPreTuning; + using FPostTuning = SearchStrategy::FPostTuning; + using FGenerateMeasureCandidates = SearchStrategy::FGenerateMeasureCandidates; + using FNotifyRunnerResults = SearchStrategy::FNotifyRunnerResults; + using FClone = SearchStrategy::FClone; + + /*! \brief The packed function to the `InitializeWithTuneContext` method. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `PreTuning` method. */ + FPreTuning f_pre_tuning; + /*! \brief The packed function to the `PostTuning` method. */ + FPostTuning f_post_tuning; + /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ + FGenerateMeasureCandidates f_generate_measure_candidates; + /*! \brief The packed function to the `NotifyRunnerResults` method. */ + FNotifyRunnerResults f_notify_runner_results; + /*! \brief The packed function to the `Clone` method. */ + FClone f_clone; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_pre_tuning` is not visited + // `f_post_tuning` is not visited + // `f_generate_measure_candidates` is not visited + // `f_notify_runner_results` is not visited + // `f_clone` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + void PreTuning(const Array& design_spaces, const Optional& database, + const Optional& cost_model) final; + void PostTuning() final; + Optional> GenerateMeasureCandidates() final; + void NotifyRunnerResults(const Array& measure_candidates, + const Array& results); + SearchStrategy Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index e88cdf825a79..276e65713325 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -151,6 +151,16 @@ def notify_runner_results( results, ) + def clone(self) -> "SearchStrategy": + """Clone the search strategy. + + Returns + ------- + cloned : SearchStrategy + The cloned search strategy. + """ + return _ffi_api.SearchStrategyClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PySearchStrategy") class _PySearchStrategy(SearchStrategy): @@ -168,6 +178,7 @@ def __init__( f_post_tuning: Callable = None, f_generate_measure_candidates: Callable = None, f_notify_runner_results: Callable = None, + f_clone: Callable = None, ): """Constructor.""" @@ -178,6 +189,7 @@ def __init__( f_post_tuning, f_generate_measure_candidates, f_notify_runner_results, + f_clone, ) @@ -197,6 +209,7 @@ class PySearchStrategy: "post_tuning", "generate_measure_candidates", "notify_runner_results", + "clone", ], } @@ -250,6 +263,16 @@ def notify_runner_results( """ raise NotImplementedError + def clone(self) -> SearchStrategy: + """Clone the search strategy. + + Returns + ------- + strategy : SearchStrategy + The cloned search strategy. + """ + raise NotImplementedError + def create( # pylint: disable=keyword-arg-before-vararg kind: Literal[ diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index c5ff9008effe..5930704eb0d1 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -431,6 +431,24 @@ class EvolutionarySearchNode : public SearchStrategyNode { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(measure_candidates, results); } + + SearchStrategy Clone() const final { + ObjectPtr n = make_object(); + n->max_trials_per_task = this->max_trials_per_task; + n->num_trials_per_iter = this->num_trials_per_iter; + n->population_size = this->population_size; + n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop; + n->init_measured_ratio = this->init_measured_ratio; + n->init_min_unmeasured = this->init_min_unmeasured; + n->genetic_num_iters = this->genetic_num_iters; + n->genetic_mutate_prob = this->genetic_mutate_prob; + n->genetic_max_fail_count = this->genetic_max_fail_count; + n->eps_greedy = this->eps_greedy; + n->context_ = this->context_; + n->rand_state_ = this->rand_state_; + n->state_ = nullptr; // cleared the state + return SearchStrategy(n); + } }; std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 4574c1c817a8..6914ab2f0f0a 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -100,6 +100,16 @@ class ReplayFuncNode : public SearchStrategyNode { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } + + SearchStrategy Clone() const final { + ObjectPtr n = make_object(); + n->num_trials_per_iter = this->num_trials_per_iter; + n->max_trials_per_task = this->max_trials_per_task; + n->context_ = this->context_; + n->rand_state_ = this->rand_state_; + n->state_ = nullptr; // cleared the state + return SearchStrategy(n); + } }; inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 64fc68394357..bd553bf037d1 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -118,6 +118,17 @@ class ReplayTraceNode : public SearchStrategyNode { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } + + SearchStrategy Clone() const final { + ObjectPtr n = make_object(); + n->num_trials_per_iter = this->num_trials_per_iter; + n->max_trials_per_task = this->max_trials_per_task; + n->max_fail_count = this->max_fail_count; + n->context_ = this->context_; + n->rand_state_ = this->rand_state_; + n->state_ = nullptr; // cleared the state + return SearchStrategy(n); + } }; inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 5865fc842248..81c7fda315b4 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -59,18 +59,25 @@ void PySearchStrategyNode::NotifyRunnerResults(const Array& me f_notify_runner_results(measure_candidates, results); } +SearchStrategy PySearchStrategyNode::Clone() const { + ICHECK(f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!"; + return f_clone(); +} + SearchStrategy SearchStrategy::PySearchStrategy( PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PySearchStrategyNode::FPreTuning f_pre_tuning, // PySearchStrategyNode::FPostTuning f_post_tuning, // PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // - PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) { + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results, // + PySearchStrategyNode::FClone f_clone) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = f_initialize_with_tune_context; n->f_pre_tuning = f_pre_tuning; n->f_post_tuning = f_post_tuning; n->f_generate_measure_candidates = f_generate_measure_candidates; n->f_notify_runner_results = f_notify_runner_results; + n->f_clone = f_clone; return SearchStrategy(n); } @@ -94,6 +101,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") .set_body_method(&SearchStrategyNode::NotifyRunnerResults); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone") + .set_body_method(&SearchStrategyNode::Clone); } // namespace meta_schedule } // namespace tvm From 4dec23aadb1fde042729aa84dd3eb8b31fe27d1e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 15 Sep 2022 18:04:25 -0700 Subject: [PATCH 10/10] Add clone func for TuneContext. --- include/tvm/meta_schedule/tune_context.h | 6 +++++ .../tvm/meta_schedule/testing/dummy_object.py | 3 +++ python/tvm/meta_schedule/tune_context.py | 10 +++++++ .../space_generator/space_generator_union.cc | 1 + src/meta_schedule/tune_context.cc | 26 +++++++++++++++++++ 5 files changed, 46 insertions(+) diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 3d732e7fbd99..4e2f00fb5a0c 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -43,6 +43,7 @@ namespace meta_schedule { class TaskSchedulerNode; class MeasureCallback; +class TuneContext; /*! \brief The auto tuning context. */ class TuneContextNode : public runtime::Object { @@ -99,6 +100,11 @@ class TuneContextNode : public runtime::Object { /*! \brief Initialize members that needs initialization with tune context. */ void Initialize(); + /*! + * \brief Clone the tune context. + * \return The cloned tune context. + */ + TuneContext Clone() const; /*! \brief Set the measure candidates from the SearchStrategy */ void _SetMeasureCandidates(const Array& candidates); /*! diff --git a/python/tvm/meta_schedule/testing/dummy_object.py b/python/tvm/meta_schedule/testing/dummy_object.py index 50ae974df5d8..bb2294544920 100644 --- a/python/tvm/meta_schedule/testing/dummy_object.py +++ b/python/tvm/meta_schedule/testing/dummy_object.py @@ -58,3 +58,6 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: def apply(self, trace: Trace, _) -> Optional[Trace]: return Trace(trace.insts, {}) + + def clone(self): + return DummyMutator() diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 17acad8d4a57..29cd94110c0c 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -331,3 +331,13 @@ def notify_runner_results( "Please construct TuneContext with search_strategy" ) return self.search_strategy.notify_runner_results(measure_candidates, results) + + def clone(self) -> "TuneContext": + """Clone the TuneContext. + + Returns + ------- + cloned_context : TuneContext + The cloned TuneContext. + """ + return _ffi_api.TuneContextClone(self) # type: ignore # pylint: disable=no-member diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 27e4beecde14..771d0c187f97 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -53,6 +53,7 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { for (const SpaceGenerator& space_generator : this->space_generators) { n->space_generators.push_back(space_generator->Clone()); } + return SpaceGenerator(n); } static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion"; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 57b2344c6f8d..3650c0374dab 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -52,6 +52,32 @@ TuneContext::TuneContext(Optional mod, data_ = std::move(n); } +TuneContext TuneContextNode::Clone() const { + ObjectPtr n = make_object(*this); + if (this->sch_rules.defined()) { + n->sch_rules = Array(); + for (const ScheduleRule& sch_rule : this->sch_rules) { + n->sch_rules.push_back(sch_rule->Clone()); + } + } + if (this->postprocs.defined()) { + n->postprocs = Array(); + for (const Postproc& postproc : this->postprocs) { + n->postprocs.push_back(postproc->Clone()); + } + } + if (this->mutator_probs.defined()) { + n->mutator_probs = Map(); + for (const auto& kv : this->mutator_probs) { + n->mutator_probs.Set(kv.first->Clone(), kv.second); + } + } + if (this->space_generator.defined()) n->space_generator = this->space_generator.value()->Clone(); + if (this->search_strategy.defined()) n->search_strategy = this->search_strategy.value()->Clone(); + n->Initialize(); + return TuneContext(n); +} + void TuneContextNode::Initialize() { if (this->space_generator.defined()) { this->space_generator.value()->InitializeWithTuneContext(GetRef(this));