diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 566cc82e9716..2b580e75e019 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -32,6 +32,7 @@ namespace tvm { namespace meta_schedule { class TuneContext; +class Mutator; /*! \brief Mutator is designed to mutate the trace to explore the design space. */ class MutatorNode : public runtime::Object { @@ -57,12 +58,21 @@ class MutatorNode : public runtime::Object { virtual Optional Apply(const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0; + /*! + * \brief Clone the mutator. + * \return The cloned mutator. + */ + virtual Mutator Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.Mutator"; TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); }; -/*! \brief The mutator with customized methods on the python-side. */ -class PyMutatorNode : public MutatorNode { +/*! + * \brief Managed reference to MutatorNode + * \sa MutatorNode + */ +class Mutator : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -76,39 +86,16 @@ class PyMutatorNode : public MutatorNode { */ using FApply = runtime::TypedPackedFunc( const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; + /*! + * \brief Clone the mutator. + * \return The cloned mutator. + */ + using FClone = runtime::TypedPackedFunc; /*! * \brief Get the mutator as string with name. * \return The string of the mutator. */ using FAsString = runtime::TypedPackedFunc; - - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` function. */ - FApply f_apply; - /*! \brief The packed function to the `AsString` function. */ - FAsString f_as_string; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_as_string` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - Optional Apply(const tir::Trace& trace, - support::LinearCongruentialEngine::TRandState* rand_state) final; - - static constexpr const char* _type_key = "meta_schedule.PyMutator"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); -}; - -/*! - * \brief Managed reference to MutatorNode - * \sa MutatorNode - */ -class Mutator : public runtime::ObjectRef { - public: /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */ TVM_DLL static Mutator MutateTileSize(); /*! @@ -136,16 +123,49 @@ class Mutator : public runtime::ObjectRef { * \brief Create a mutator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_clone The packed function of `Clone`. * \param f_as_string The packed function of `AsString`. * \return The mutator created. */ - TVM_DLL static Mutator PyMutator( - PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PyMutatorNode::FApply f_apply, // - PyMutatorNode::FAsString f_as_string); + TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, // + FApply f_apply, // + FClone f_clone, // + FAsString f_as_string); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); }; +/*! \brief The mutator with customized methods on the python-side. */ +class PyMutatorNode : public MutatorNode { + public: + using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext; + using FApply = Mutator::FApply; + using FClone = Mutator::FClone; + using FAsString = Mutator::FAsString; + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_clone` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) final; + Mutator Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PyMutator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 5d99f6845463..4fafb9557631 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -29,6 +29,7 @@ namespace tvm { namespace meta_schedule { class TuneContext; +class Postproc; /*! * \brief Rules to apply a postprocessor to a schedule. @@ -54,12 +55,21 @@ class PostprocNode : public runtime::Object { */ virtual bool Apply(const tir::Schedule& sch) = 0; + /*! + * \brief Clone the postprocessor. + * \return The cloned postprocessor. + */ + virtual Postproc Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.Postproc"; TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); }; -/*! \brief The postprocessor with customized methods on the python-side. */ -class PyPostprocNode : public PostprocNode { +/*! + * \brief Managed reference to PostprocNode + * \sa PostprocNode + */ +class Postproc : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -72,49 +82,28 @@ class PyPostprocNode : public PostprocNode { * \return Whether the postprocessor was successfully applied. */ using FApply = runtime::TypedPackedFunc; + /*! + * \brief Clone the postprocessor. + * \return The cloned postprocessor. + */ + using FClone = runtime::TypedPackedFunc; /*! * \brief Get the postprocessor function as string with name. * \return The string of the postprocessor function. */ using FAsString = runtime::TypedPackedFunc; - - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` function. */ - FApply f_apply; - /*! \brief The packed function to the `AsString` function. */ - FAsString f_as_string; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_as_string` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - bool Apply(const tir::Schedule& sch) final; - - static constexpr const char* _type_key = "meta_schedule.PyPostproc"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); -}; - -/*! - * \brief Managed reference to PostprocNode - * \sa PostprocNode - */ -class Postproc : public runtime::ObjectRef { - public: /*! * \brief Create a postprocessor with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_clone The packed function of `Clone`. * \param f_as_string The packed function of `AsString`. * \return The postprocessor created. */ - TVM_DLL static Postproc PyPostproc( - PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PyPostprocNode::FApply f_apply, // - PyPostprocNode::FAsString f_as_string); + TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, // + FApply f_apply, // + FClone f_clone, // + FAsString f_as_string); /*! * \brief Create a postprocessor that checks if all loops are static * \return The postprocessor created @@ -164,6 +153,37 @@ class Postproc : public runtime::ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); }; +/*! \brief The postprocessor with customized methods on the python-side. */ +class PyPostprocNode : public PostprocNode { + public: + using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext; + using FApply = Postproc::FApply; + using FClone = Postproc::FClone; + using FAsString = Postproc::FAsString; + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_clone` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PyPostproc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 2da441c95e0b..55704cf4a97d 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -34,6 +34,7 @@ namespace tvm { namespace meta_schedule { class TuneContext; +class ScheduleRule; /*! \brief Rules to modify a block in a schedule. */ class ScheduleRuleNode : public runtime::Object { @@ -59,12 +60,21 @@ class ScheduleRuleNode : public runtime::Object { virtual runtime::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; + /*! + * \brief Deep clone the schedule rule. + * \return The cloned schedule rule. + */ + virtual ScheduleRule Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); }; -/*! \brief The schedule rule with customized methods on the python-side. */ -class PyScheduleRuleNode : public ScheduleRuleNode { +/*! + * \brief Managed reference to ScheduleRuleNode + * \sa ScheduleRuleNode + */ +class ScheduleRule : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -84,33 +94,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode { * \return The string of the schedule rule. */ using FAsString = runtime::TypedPackedFunc; - - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` function. */ - FApply f_apply; - /*! \brief The packed function to the `AsString` function. */ - FAsString f_as_string; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_as_string` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; - - static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); -}; - -/*! - * \brief Managed reference to ScheduleRuleNode - * \sa ScheduleRuleNode - */ -class ScheduleRule : public runtime::ObjectRef { - public: + /*! + * \brief The function type of `Clone` method. + * \return The cloned schedule rule. + */ + using FClone = runtime::TypedPackedFunc; /*! * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions * \param into_producer If allows to inline a block into its producer @@ -249,16 +237,50 @@ class ScheduleRule : public runtime::ObjectRef { * \brief Create a schedule rule with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_clone The packed function of `Clone`. * \param f_as_string The packed function of `AsString`. * \return The schedule rule created. */ TVM_DLL static ScheduleRule PyScheduleRule( - PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PyScheduleRuleNode::FApply f_apply, // - PyScheduleRuleNode::FAsString f_as_string); + FInitializeWithTuneContext f_initialize_with_tune_context, // + FApply f_apply, // + FClone f_clone, // + FAsString f_as_string); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); }; +/*! \brief The schedule rule with customized methods on the python-side. */ +class PyScheduleRuleNode : public ScheduleRuleNode { + public: + using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext; + using FApply = ScheduleRule::FApply; + using FClone = ScheduleRule::FClone; + using FAsString = ScheduleRule::FAsString; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + // `f_clone` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; + ScheduleRule Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index a75a4cd8ae86..efd3dc24524a 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -36,6 +36,7 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class SearchStrategy; /*! * \brief The search strategy for measure candidates generation. @@ -119,12 +120,21 @@ class SearchStrategyNode : public runtime::Object { virtual void NotifyRunnerResults(const Array& measure_candidates, const Array& results) = 0; + /*! + * \brief Clone the search strategy. + * \return The cloned search strategy. + */ + virtual SearchStrategy Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); }; -/*! \brief The python side customizable class for measure candidate generation */ -class PySearchStrategyNode : public SearchStrategyNode { +/*! + * \brief Managed reference to SearchStrategyNode. + * \sa SearchStrategyNode + */ +class SearchStrategy : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -150,44 +160,11 @@ class PySearchStrategyNode : public SearchStrategyNode { */ using FNotifyRunnerResults = runtime::TypedPackedFunc&, const Array&)>; - - /*! \brief The packed function to the `InitializeWithTuneContext` method. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `PreTuning` method. */ - FPreTuning f_pre_tuning; - /*! \brief The packed function to the `PostTuning` method. */ - FPostTuning f_post_tuning; - /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ - FGenerateMeasureCandidates f_generate_measure_candidates; - /*! \brief The packed function to the `NotifyRunnerResults` method. */ - FNotifyRunnerResults f_notify_runner_results; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_pre_tuning` is not visited - // `f_post_tuning` is not visited - // `f_generate_measure_candidates` is not visited - // `f_notify_runner_results` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - void PreTuning(const Array& design_spaces, const Optional& database, - const Optional& cost_model) final; - void PostTuning() final; - Optional> GenerateMeasureCandidates() final; - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results); - - static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; - TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); -}; - -/*! - * \brief Managed reference to SearchStrategyNode. - * \sa SearchStrategyNode - */ -class SearchStrategy : public runtime::ObjectRef { - public: + /*! + * \brief The function type of `Clone` method. + * \return The cloned search strategy. + */ + using FClone = runtime::TypedPackedFunc; /*! * \brief Create a search strategy with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. @@ -195,14 +172,16 @@ class SearchStrategy : public runtime::ObjectRef { * \param f_post_tuning The packed function of `PostTuning`. * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. + * \param f_clone The packed function of `Clone`. * \return The search strategy created. */ TVM_DLL static SearchStrategy PySearchStrategy( - PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // - PySearchStrategyNode::FPreTuning f_pre_tuning, // - PySearchStrategyNode::FPostTuning f_post_tuning, // - PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // - PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); + FInitializeWithTuneContext f_initialize_with_tune_context, // + FPreTuning f_pre_tuning, // + FPostTuning f_post_tuning, // + FGenerateMeasureCandidates f_generate_measure_candidates, // + FNotifyRunnerResults f_notify_runner_results, // + FClone f_clone); /*! * \brief Constructor of replay trace search strategy. @@ -245,6 +224,51 @@ class SearchStrategy : public runtime::ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; +/*! \brief The python side customizable class for measure candidate generation */ +class PySearchStrategyNode : public SearchStrategyNode { + public: + using FInitializeWithTuneContext = SearchStrategy::FInitializeWithTuneContext; + using FPreTuning = SearchStrategy::FPreTuning; + using FPostTuning = SearchStrategy::FPostTuning; + using FGenerateMeasureCandidates = SearchStrategy::FGenerateMeasureCandidates; + using FNotifyRunnerResults = SearchStrategy::FNotifyRunnerResults; + using FClone = SearchStrategy::FClone; + + /*! \brief The packed function to the `InitializeWithTuneContext` method. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `PreTuning` method. */ + FPreTuning f_pre_tuning; + /*! \brief The packed function to the `PostTuning` method. */ + FPostTuning f_post_tuning; + /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ + FGenerateMeasureCandidates f_generate_measure_candidates; + /*! \brief The packed function to the `NotifyRunnerResults` method. */ + FNotifyRunnerResults f_notify_runner_results; + /*! \brief The packed function to the `Clone` method. */ + FClone f_clone; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_pre_tuning` is not visited + // `f_post_tuning` is not visited + // `f_generate_measure_candidates` is not visited + // `f_notify_runner_results` is not visited + // `f_clone` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + void PreTuning(const Array& design_spaces, const Optional& database, + const Optional& cost_model) final; + void PostTuning() final; + Optional> GenerateMeasureCandidates() final; + void NotifyRunnerResults(const Array& measure_candidates, + const Array& results); + SearchStrategy Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 2c1b2d4e4d7d..1e29e757a15c 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -31,6 +31,7 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class SpaceGenerator; /*! * \brief The abstract class for design space generation. @@ -87,12 +88,21 @@ class SpaceGeneratorNode : public runtime::Object { */ virtual Array GenerateDesignSpace(const IRModule& mod) = 0; + /*! + * \brief Clone the space generator. + * \return The cloned space generator. + */ + virtual SpaceGenerator Clone() const = 0; + static constexpr const char* _type_key = "meta_schedule.SpaceGenerator"; TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object); }; -/*! \brief The design space generator with customized methods on the python-side. */ -class PySpaceGeneratorNode : public SpaceGeneratorNode { +/*! + * \brief Managed reference to SpaceGeneratorNode. + * \sa SpaceGeneratorNode + */ +class SpaceGenerator : public runtime::ObjectRef { public: /*! * \brief The function type of `InitializeWithTuneContext` method. @@ -105,29 +115,12 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { * \return The generated design spaces, i.e., schedules. */ using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; + /*! + * \brief The function type of `Clone` method. + * \return The cloned space generator. + */ + using FClone = runtime::TypedPackedFunc; - /*! \brief The packed function to the `InitializeWithTuneContext` function. */ - FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `GenerateDesignSpace` function. */ - FGenerateDesignSpace f_generate_design_space; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_generate_design_space` is not visited - } - - void InitializeWithTuneContext(const TuneContext& context) final; - Array GenerateDesignSpace(const IRModule& mod) final; - - static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; - TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); -}; - -/*! - * \brief Managed reference to SpaceGeneratorNode. - * \sa SpaceGeneratorNode - */ -class SpaceGenerator : public runtime::ObjectRef { protected: SpaceGenerator() = default; @@ -136,11 +129,12 @@ class SpaceGenerator : public runtime::ObjectRef { * \brief Create a design space generator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_generate_design_space The packed function of `GenerateDesignSpace`. + * \param f_clone The packed function of `Clone`. * \return The design space generator created. */ TVM_DLL static SpaceGenerator PySpaceGenerator( - PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, - PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space); + FInitializeWithTuneContext f_initialize_with_tune_context, + FGenerateDesignSpace f_generate_design_space, FClone f_clone); /*! * \brief Create a design space generator with customized schedule function. * \param schedule_fn The schedule function, which can have the following signatures: @@ -156,14 +150,40 @@ class SpaceGenerator : public runtime::ObjectRef { */ TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); /*! - * \brief Create a design space generator that generates design spaces by applying schedule rules - * to blocks in post-DFS order. - * \return The design space generator created. + * \brief Create a design space generator that generates design spaces by applying schedule + * rules to blocks in post-DFS order. \return The design space generator created. */ TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; +/*! \brief The design space generator with customized methods on the python-side. */ +class PySpaceGeneratorNode : public SpaceGeneratorNode { + public: + using FInitializeWithTuneContext = SpaceGenerator::FInitializeWithTuneContext; + using FGenerateDesignSpace = SpaceGenerator::FGenerateDesignSpace; + using FClone = SpaceGenerator::FClone; + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `GenerateDesignSpace` function. */ + FGenerateDesignSpace f_generate_design_space; + /*! \brief The packed function to the `Clone` function. */ + FClone f_clone; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_generate_design_space` is not visited + // `f_clone` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final; + Array GenerateDesignSpace(const IRModule& mod) final; + SpaceGenerator Clone() const final; + + static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 3d732e7fbd99..4e2f00fb5a0c 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -43,6 +43,7 @@ namespace meta_schedule { class TaskSchedulerNode; class MeasureCallback; +class TuneContext; /*! \brief The auto tuning context. */ class TuneContextNode : public runtime::Object { @@ -99,6 +100,11 @@ class TuneContextNode : public runtime::Object { /*! \brief Initialize members that needs initialization with tune context. */ void Initialize(); + /*! + * \brief Clone the tune context. + * \return The cloned tune context. + */ + TuneContext Clone() const; /*! \brief Set the measure candidates from the SearchStrategy */ void _SetMeasureCandidates(const Array& candidates); /*! diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 0c8de9668034..c5286aced7d8 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -58,6 +58,16 @@ def apply(self, trace: Trace) -> Optional[Trace]: """ return _ffi_api.MutatorApply(self, trace, -1) # type: ignore # pylint: disable=no-member + def clone(self) -> "Mutator": + """Clone the mutator. + + Returns + ------- + mutator : Mutator + The cloned mutator. + """ + return _ffi_api.MutatorClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PyMutator") class _PyMutator(Mutator): @@ -72,6 +82,7 @@ def __init__( self, f_initialize_with_tune_context: Callable = None, f_apply: Callable = None, + f_clone: Callable = None, f_as_string: Callable = None, ): """Constructor.""" @@ -80,6 +91,7 @@ def __init__( _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_apply, + f_clone, f_as_string, ) @@ -94,7 +106,7 @@ class PyMutator: _tvm_metadata = { "cls": _PyMutator, - "methods": ["_initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -122,6 +134,16 @@ def apply(self, trace: Trace, _) -> Optional[Trace]: """ raise NotImplementedError + def clone(self) -> Mutator: + """Clone the mutator. + + Returns + ------- + mutator : Mutator + The cloned mutator. + """ + raise NotImplementedError + def __str__(self) -> str: """Get the mutator as string with name. diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index e37666bd1ce0..6eec2965ceeb 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -60,6 +60,16 @@ def apply(self, sch: Schedule) -> bool: """ return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member + def clone(self) -> "Postproc": + """Clone the postprocessor. + + Returns + ------- + cloned_postproc : Postproc + The cloned postprocessor. + """ + return _ffi_api.PostprocClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PyPostproc") class _PyPostproc(Postproc): @@ -74,6 +84,7 @@ def __init__( self, f_initialize_with_tune_context: Callable = None, f_apply: Callable = None, + f_clone: Callable = None, f_as_string: Callable = None, ): """Constructor.""" @@ -82,6 +93,7 @@ def __init__( _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_apply, + f_clone, f_as_string, ) @@ -96,7 +108,7 @@ class PyPostproc: _tvm_metadata = { "cls": _PyPostproc, - "methods": ["_initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -124,6 +136,16 @@ def apply(self, sch: Schedule) -> bool: """ raise NotImplementedError + def clone(self) -> Postproc: + """Clone the postprocessor. + + Returns + ------- + cloned_postproc : Postproc + The cloned postprocessor. + """ + raise NotImplementedError + def __str__(self) -> str: """Get the post processor as string with name. diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 481444341b86..2c8e223611aa 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -66,6 +66,16 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: self, sch, block ) + def clone(self) -> "ScheduleRule": + """Deep clone the schedule rule. + + Returns + ------- + cloned_rule : ScheduleRule + The cloned schedule rule. + """ + return _ffi_api.ScheduleRuleClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PyScheduleRule") class _PyScheduleRule(ScheduleRule): @@ -80,6 +90,7 @@ def __init__( self, f_initialize_with_tune_context: Callable = None, f_apply: Callable = None, + f_clone: Callable = None, f_as_string: Callable = None, ): """Constructor.""" @@ -88,6 +99,7 @@ def __init__( _ffi_api.ScheduleRulePyScheduleRule, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_apply, + f_clone, f_as_string, ) @@ -102,7 +114,7 @@ class PyScheduleRule: _tvm_metadata = { "cls": _PyScheduleRule, - "methods": ["_initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -113,9 +125,7 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: context : TuneContext The tuning context for initializing the schedule rule. """ - _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member - self, context - ) + raise NotImplementedError def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: """Apply a schedule rule to the specific block in the given schedule. @@ -132,9 +142,17 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: design_spaces : List[Schedule] The list of schedules generated by applying the schedule rule. """ - return _ffi_api.ScheduleRuleApply( # type: ignore # pylint: disable=no-member - self, sch, block - ) + raise NotImplementedError + + def clone(self) -> ScheduleRule: + """Deep clone the schedule rule. + + Returns + ------- + cloned_rule : ScheduleRule + The cloned schedule rule. + """ + raise NotImplementedError def __str__(self) -> str: """Get the schedule rule as string with name. diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index e88cdf825a79..276e65713325 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -151,6 +151,16 @@ def notify_runner_results( results, ) + def clone(self) -> "SearchStrategy": + """Clone the search strategy. + + Returns + ------- + cloned : SearchStrategy + The cloned search strategy. + """ + return _ffi_api.SearchStrategyClone(self) # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PySearchStrategy") class _PySearchStrategy(SearchStrategy): @@ -168,6 +178,7 @@ def __init__( f_post_tuning: Callable = None, f_generate_measure_candidates: Callable = None, f_notify_runner_results: Callable = None, + f_clone: Callable = None, ): """Constructor.""" @@ -178,6 +189,7 @@ def __init__( f_post_tuning, f_generate_measure_candidates, f_notify_runner_results, + f_clone, ) @@ -197,6 +209,7 @@ class PySearchStrategy: "post_tuning", "generate_measure_candidates", "notify_runner_results", + "clone", ], } @@ -250,6 +263,16 @@ def notify_runner_results( """ raise NotImplementedError + def clone(self) -> SearchStrategy: + """Clone the search strategy. + + Returns + ------- + strategy : SearchStrategy + The cloned search strategy. + """ + raise NotImplementedError + def create( # pylint: disable=keyword-arg-before-vararg kind: Literal[ diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 9d7ebf3bae26..23c0361645b5 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -72,6 +72,16 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]: """ return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member + def clone(self) -> "SpaceGenerator": + """Clone the design space generator. + + Returns + ------- + cloned_sg : SpaceGenerator + The cloned design space generator. + """ + return _ffi_api.SpaceGeneratorClone(self) # type: ignore # pylint: disable=no-member + ScheduleFnType = SpaceGenerator.ScheduleFnType @@ -89,6 +99,7 @@ def __init__( self, f_initialize_with_tune_context: Optional[Callable] = None, f_generate_design_space: Optional[Callable] = None, + f_clone: Optional[Callable] = None, ): """Constructor.""" @@ -96,6 +107,7 @@ def __init__( _ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_generate_design_space, + f_clone, ) @@ -109,7 +121,7 @@ class PySpaceGenerator: _tvm_metadata = { "cls": _PySpaceGenerator, - "methods": ["_initialize_with_tune_context", "generate_design_space"], + "methods": ["_initialize_with_tune_context", "generate_design_space", "clone"], } def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -137,6 +149,16 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]: """ raise NotImplementedError + def clone(self) -> SpaceGenerator: + """Clone the design space generator. + + Returns + ------- + cloned_sg : SpaceGenerator + The cloned design space generator. + """ + raise NotImplementedError + def create( # pylint: disable=keyword-arg-before-vararg kind: Union[ diff --git a/python/tvm/meta_schedule/testing/dummy_object.py b/python/tvm/meta_schedule/testing/dummy_object.py index 50ae974df5d8..bb2294544920 100644 --- a/python/tvm/meta_schedule/testing/dummy_object.py +++ b/python/tvm/meta_schedule/testing/dummy_object.py @@ -58,3 +58,6 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: def apply(self, trace: Trace, _) -> Optional[Trace]: return Trace(trace.insts, {}) + + def clone(self): + return DummyMutator() diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 17acad8d4a57..29cd94110c0c 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -331,3 +331,13 @@ def notify_runner_results( "Please construct TuneContext with search_strategy" ) return self.search_strategy.notify_runner_results(measure_candidates, results) + + def clone(self) -> "TuneContext": + """Clone the TuneContext. + + Returns + ------- + cloned_context : TuneContext + The cloned TuneContext. + """ + return _ffi_api.TuneContextClone(self) # type: ignore # pylint: disable=no-member diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 9d6d69ba355f..2a31d2da9b53 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -42,6 +42,11 @@ class MutateComputeLocationNode : public MutatorNode { } // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } private: struct Candidate { diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 82b91da682c6..9feb4747d807 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -188,6 +188,11 @@ class MutateParallelNode : public MutatorNode { } // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } }; /*! \brief The candidate to be mutated */ diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index de780b53e2d9..f5d89a85092b 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -42,6 +42,11 @@ class MutateThreadBindingNode : public MutatorNode { } // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } private: struct Candidate { diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 4a3bfda8a4a8..8fb83147ea7b 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -63,6 +63,11 @@ class MutateTileSizeNode : public MutatorNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } }; /*! diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index c282a171c3b7..7bbf00343af3 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -60,6 +60,11 @@ class MutateUnrollNode : public MutatorNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` Optional Apply(const Trace& trace, TRandState* rand_state) final; + // Inherit from `MutatorNode` + Mutator Clone() const final { + ObjectPtr n = make_object(*this); + return Mutator(n); + } }; /*! \brief A candidate to be mutated */ diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 43b95000c71d..25312ab61f99 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -33,13 +33,20 @@ Optional PyMutatorNode::Apply( return f_apply(trace, *rand_state); } +Mutator PyMutatorNode::Clone() const { + ICHECK(f_clone != nullptr) << "PyMutator's Clone method not implemented!"; + return f_clone(); +} + Mutator Mutator::PyMutator( PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyMutatorNode::FApply f_apply, // + PyMutatorNode::FClone f_clone, // PyMutatorNode::FAsString f_as_string) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); + n->f_clone = std::move(f_clone); n->f_as_string = std::move(f_as_string); return Mutator(n); } @@ -63,6 +70,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); return self->Apply(trace, &seed_); }); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index 85a81f10fdcd..8362da552ea5 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -67,6 +67,11 @@ class DisallowDynamicLoopNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } + // Inherited from PostprocNode + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index 0f4f1b1192f6..957d6e7364e4 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -32,13 +32,20 @@ bool PyPostprocNode::Apply(const tir::Schedule& sch) { return f_apply(sch); } +Postproc PyPostprocNode::Clone() const { + ICHECK(f_clone != nullptr) << "PyPostproc's Clone method not implemented!"; + return f_clone(); +} + Postproc Postproc::PyPostproc( PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyPostprocNode::FApply f_apply, // + PyPostprocNode::FClone f_clone, // PyPostprocNode::FAsString f_as_string) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); + n->f_clone = std::move(f_clone); n->f_as_string = std::move(f_as_string); return Postproc(n); } @@ -58,6 +65,7 @@ TVM_REGISTER_NODE_TYPE(PyPostprocNode); TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") .set_body_method(&PostprocNode::InitializeWithTuneContext); TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index d111bdb42abb..ac9f45ca8ef4 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -104,6 +104,11 @@ class RewriteCooperativeFetchNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index f4cbdfe737fb..6ff9958c791f 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -167,6 +167,11 @@ class RewriteLayoutNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); } + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + static constexpr const char* _type_key = "meta_schedule.RewriteLayout"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode); }; diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 08d25d017840..c3cc0ef60152 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -384,6 +384,12 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { return true; } + Postproc Clone() const { + ObjectPtr n = + make_object(*this); + return Postproc(n); + } + static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); }; diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index ea204e306133..05a7640f047c 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -114,6 +114,11 @@ class RewriteReductionBlockNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 3b6c438d0216..4f8e0fb213f8 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -68,6 +68,11 @@ class RewriteTensorizeNode : public PostprocNode { void VisitAttrs(tvm::AttrVisitor* v) {} + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + bool vectorize_init_loop = false; static constexpr const char* _type_key = "meta_schedule.RewriteTensorize"; diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index eb57e90f82f6..1ba68538ea04 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -97,6 +97,11 @@ class RewriteUnboundBlockNode : public PostprocNode { // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final; + Postproc Clone() const { + ObjectPtr n = make_object(*this); + return Postproc(n); + } + public: /*! \brief The max number of threads per block from Target */ int max_threads_per_block_ = -1; diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index dfe2c5a06a17..0828ee538427 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -196,6 +196,12 @@ class VerifyGPUCodeNode : public PostprocNode { return true; } + Postproc Clone() const { + ObjectPtr n = make_object(*this); + n->target_constraints_ = this->target_constraints_; + return Postproc(n); + } + static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode"; TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode); }; diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index cf87f24ac233..2fc1352677cb 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -36,6 +36,12 @@ class AddRFactorNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + public: /*! * \brief The maximum number of jobs to be launched per core. diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index d8f52fa8e1de..7af1418d8f3e 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -177,6 +177,12 @@ class AutoBindNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + public: /*! \brief The max number of threads per block from Target */ int64_t max_threads_per_block_ = -1; diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 446c8ead7e8e..dcdc83f95cb1 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -60,6 +60,12 @@ class AutoInlineNode : public ScheduleRuleNode { return {sch}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + public: /*! \brief If allows to inline a block into its producer */ bool into_producer; diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 35be33f72e21..f2fc67f74cc7 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -113,6 +113,12 @@ class CrossThreadReductionNode : public ScheduleRuleNode { return {tmp_sch, sch}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + private: /*! * \brief Check whether the input block is in thread scope, i.e., some of its outer loop is diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index c126c854462c..1625a27b9aaf 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -104,6 +104,12 @@ Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& return results; } +// Inherited from ScheduleRuleNode +ScheduleRule MultiLevelTilingNode::Clone() const { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); +} + std::vector MultiLevelTilingNode::ApplySubRules(std::vector states) { states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); }); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 9161a972c187..47da878c3be0 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -155,6 +155,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode { // Entry of the mega rule; Inherited from ScheduleRuleNode Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const override; + protected: virtual std::vector ApplySubRules(std::vector states); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 7ddda9b2635b..13b00fa7deb6 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -137,6 +137,13 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { // Override Apply to apply tensorization-specific analysis before applying sub-rules Array Apply(const Schedule& sch, const BlockRV& block_rv) final; + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = + make_object(*this); + return ScheduleRule(n); + } + /*! * \brief Transform and tensorize with the given tensor intrin * \param state The state of the meta schedule rule diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 3a299ed041e2..b953d1ad4b50 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -63,6 +63,13 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { return res; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = + make_object(*this); + return ScheduleRule(n); + } + // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then // tile the outerloops. virtual std::vector ApplySubRules(std::vector states) { diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 19758996e608..045aa85b73ad 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -79,6 +79,13 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { return {sch}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = + make_object(*this); + return ScheduleRule(n); + } + public: /*! * \brief The maximum number of jobs to be launched per CPU core. It sets the diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 65988dfd5688..7796eddd44d3 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -57,6 +57,12 @@ class RandomComputeLocationNode : public ScheduleRuleNode { return {res}; } + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + return ScheduleRule(n); + } + private: bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { tir::StmtSRef block_sref = sch->GetSRef(block_rv); diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 80f8725b0c0d..416b43f46d56 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -33,13 +33,20 @@ Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, return f_apply(sch, block); } +ScheduleRule PyScheduleRuleNode::Clone() const { + ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!"; + return f_clone(); +} + ScheduleRule ScheduleRule::PyScheduleRule( PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FClone f_clone, // PyScheduleRuleNode::FAsString f_as_string) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); + n->f_clone = std::move(f_clone); n->f_as_string = std::move(f_as_string); return ScheduleRule(n); } @@ -60,6 +67,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") .set_body_method(&ScheduleRuleNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone") + .set_body_method(&ScheduleRuleNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") .set_body_typed(ScheduleRule::PyScheduleRule); diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index c5ff9008effe..5930704eb0d1 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -431,6 +431,24 @@ class EvolutionarySearchNode : public SearchStrategyNode { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(measure_candidates, results); } + + SearchStrategy Clone() const final { + ObjectPtr n = make_object(); + n->max_trials_per_task = this->max_trials_per_task; + n->num_trials_per_iter = this->num_trials_per_iter; + n->population_size = this->population_size; + n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop; + n->init_measured_ratio = this->init_measured_ratio; + n->init_min_unmeasured = this->init_min_unmeasured; + n->genetic_num_iters = this->genetic_num_iters; + n->genetic_mutate_prob = this->genetic_mutate_prob; + n->genetic_max_fail_count = this->genetic_max_fail_count; + n->eps_greedy = this->eps_greedy; + n->context_ = this->context_; + n->rand_state_ = this->rand_state_; + n->state_ = nullptr; // cleared the state + return SearchStrategy(n); + } }; std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 4574c1c817a8..6914ab2f0f0a 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -100,6 +100,16 @@ class ReplayFuncNode : public SearchStrategyNode { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } + + SearchStrategy Clone() const final { + ObjectPtr n = make_object(); + n->num_trials_per_iter = this->num_trials_per_iter; + n->max_trials_per_task = this->max_trials_per_task; + n->context_ = this->context_; + n->rand_state_ = this->rand_state_; + n->state_ = nullptr; // cleared the state + return SearchStrategy(n); + } }; inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 64fc68394357..bd553bf037d1 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -118,6 +118,17 @@ class ReplayTraceNode : public SearchStrategyNode { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } + + SearchStrategy Clone() const final { + ObjectPtr n = make_object(); + n->num_trials_per_iter = this->num_trials_per_iter; + n->max_trials_per_task = this->max_trials_per_task; + n->max_fail_count = this->max_fail_count; + n->context_ = this->context_; + n->rand_state_ = this->rand_state_; + n->state_ = nullptr; // cleared the state + return SearchStrategy(n); + } }; inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 5865fc842248..81c7fda315b4 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -59,18 +59,25 @@ void PySearchStrategyNode::NotifyRunnerResults(const Array& me f_notify_runner_results(measure_candidates, results); } +SearchStrategy PySearchStrategyNode::Clone() const { + ICHECK(f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!"; + return f_clone(); +} + SearchStrategy SearchStrategy::PySearchStrategy( PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PySearchStrategyNode::FPreTuning f_pre_tuning, // PySearchStrategyNode::FPostTuning f_post_tuning, // PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // - PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) { + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results, // + PySearchStrategyNode::FClone f_clone) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = f_initialize_with_tune_context; n->f_pre_tuning = f_pre_tuning; n->f_post_tuning = f_post_tuning; n->f_generate_measure_candidates = f_generate_measure_candidates; n->f_notify_runner_results = f_notify_runner_results; + n->f_clone = f_clone; return SearchStrategy(n); } @@ -94,6 +101,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") .set_body_method(&SearchStrategyNode::NotifyRunnerResults); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone") + .set_body_method(&SearchStrategyNode::Clone); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 9be89e2d9c70..991e4fa08047 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -188,6 +188,15 @@ class PostOrderApplyNode : public SpaceGeneratorNode { } return result; } + + SpaceGenerator Clone() const final { + ObjectPtr n = make_object(*this); + n->sch_rules_ = Array(); + for (const ScheduleRule& sch_rule : this->sch_rules_) { + n->sch_rules_.push_back(sch_rule->Clone()); + } + return SpaceGenerator(n); + } static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 70559fbcf1fb..adea139b1cd4 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -72,6 +72,11 @@ class ScheduleFnNode : public SpaceGeneratorNode { throw; } + SpaceGenerator Clone() const final { + ObjectPtr n = make_object(*this); + return SpaceGenerator(n); + } + static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); }; diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 5c5ab6ebbae5..6fc31ed896f2 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -33,12 +33,18 @@ Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& m return f_generate_design_space(mod); } +SpaceGenerator PySpaceGeneratorNode::Clone() const { + ICHECK(f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!"; + return f_clone(); +} + SpaceGenerator SpaceGenerator::PySpaceGenerator( - PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, - PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space) { + FInitializeWithTuneContext f_initialize_with_tune_context, + FGenerateDesignSpace f_generate_design_space, FClone f_clone) { ObjectPtr n = make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_generate_design_space = std::move(f_generate_design_space); + n->f_clone = std::move(f_clone); return SpaceGenerator(n); } @@ -51,6 +57,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") .set_body_method(&SpaceGeneratorNode::GenerateDesignSpace); TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") .set_body_typed(SpaceGenerator::PySpaceGenerator); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone") + .set_body_method(&SpaceGeneratorNode::Clone); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 6ea61824f932..771d0c187f97 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -47,6 +47,15 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { return design_spaces; } + SpaceGenerator Clone() const final { + ObjectPtr n = make_object(*this); + n->space_generators = Array(); + for (const SpaceGenerator& space_generator : this->space_generators) { + n->space_generators.push_back(space_generator->Clone()); + } + return SpaceGenerator(n); + } + static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion"; TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode); }; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 57b2344c6f8d..3650c0374dab 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -52,6 +52,32 @@ TuneContext::TuneContext(Optional mod, data_ = std::move(n); } +TuneContext TuneContextNode::Clone() const { + ObjectPtr n = make_object(*this); + if (this->sch_rules.defined()) { + n->sch_rules = Array(); + for (const ScheduleRule& sch_rule : this->sch_rules) { + n->sch_rules.push_back(sch_rule->Clone()); + } + } + if (this->postprocs.defined()) { + n->postprocs = Array(); + for (const Postproc& postproc : this->postprocs) { + n->postprocs.push_back(postproc->Clone()); + } + } + if (this->mutator_probs.defined()) { + n->mutator_probs = Map(); + for (const auto& kv : this->mutator_probs) { + n->mutator_probs.Set(kv.first->Clone(), kv.second); + } + } + if (this->space_generator.defined()) n->space_generator = this->space_generator.value()->Clone(); + if (this->search_strategy.defined()) n->search_strategy = this->search_strategy.value()->Clone(); + n->Initialize(); + return TuneContext(n); +} + void TuneContextNode::Initialize() { if (this->space_generator.defined()) { this->space_generator.value()->InitializeWithTuneContext(GetRef(this));