Skip to content
88 changes: 54 additions & 34 deletions include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace tvm {
namespace meta_schedule {

class TuneContext;
class Mutator;

/*! \brief Mutator is designed to mutate the trace to explore the design space. */
class MutatorNode : public runtime::Object {
Expand All @@ -57,12 +58,21 @@ class MutatorNode : public runtime::Object {
virtual Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) = 0;

/*!
* \brief Clone the mutator.
* \return The cloned mutator.
*/
virtual Mutator Clone() const = 0;

static constexpr const char* _type_key = "meta_schedule.Mutator";
TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
};

/*! \brief The mutator with customized methods on the python-side. */
class PyMutatorNode : public MutatorNode {
/*!
* \brief Managed reference to MutatorNode
* \sa MutatorNode
*/
class Mutator : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
Expand All @@ -76,39 +86,16 @@ class PyMutatorNode : public MutatorNode {
*/
using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(
const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
/*!
* \brief Clone the mutator.
* \return The cloned mutator.
*/
using FClone = runtime::TypedPackedFunc<Mutator()>;
/*!
* \brief Get the mutator as string with name.
* \return The string of the mutator.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) final;

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
};

/*!
* \brief Managed reference to MutatorNode
* \sa MutatorNode
*/
class Mutator : public runtime::ObjectRef {
public:
/*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
TVM_DLL static Mutator MutateTileSize();
/*!
Expand Down Expand Up @@ -136,16 +123,49 @@ class Mutator : public runtime::ObjectRef {
* \brief Create a mutator with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The mutator created.
*/
TVM_DLL static Mutator PyMutator(
PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyMutatorNode::FApply f_apply, //
PyMutatorNode::FAsString f_as_string);
TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
};

/*! \brief The mutator with customized methods on the python-side. */
class PyMutatorNode : public MutatorNode {
public:
using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext;
using FApply = Mutator::FApply;
using FClone = Mutator::FClone;
using FAsString = Mutator::FAsString;
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_clone` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) final;
Mutator Clone() const final;

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
};

} // namespace meta_schedule
} // namespace tvm

Expand Down
86 changes: 53 additions & 33 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace tvm {
namespace meta_schedule {

class TuneContext;
class Postproc;

/*!
* \brief Rules to apply a postprocessor to a schedule.
Expand All @@ -54,12 +55,21 @@ class PostprocNode : public runtime::Object {
*/
virtual bool Apply(const tir::Schedule& sch) = 0;

/*!
* \brief Clone the postprocessor.
* \return The cloned postprocessor.
*/
virtual Postproc Clone() const = 0;

static constexpr const char* _type_key = "meta_schedule.Postproc";
TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
};

/*! \brief The postprocessor with customized methods on the python-side. */
class PyPostprocNode : public PostprocNode {
/*!
* \brief Managed reference to PostprocNode
* \sa PostprocNode
*/
class Postproc : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
Expand All @@ -72,49 +82,28 @@ class PyPostprocNode : public PostprocNode {
* \return Whether the postprocessor was successfully applied.
*/
using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
/*!
* \brief Clone the postprocessor.
* \return The cloned postprocessor.
*/
using FClone = runtime::TypedPackedFunc<Postproc()>;
/*!
* \brief Get the postprocessor function as string with name.
* \return The string of the postprocessor function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
bool Apply(const tir::Schedule& sch) final;

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
};

/*!
* \brief Managed reference to PostprocNode
* \sa PostprocNode
*/
class Postproc : public runtime::ObjectRef {
public:
/*!
* \brief Create a postprocessor with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The postprocessor created.
*/
TVM_DLL static Postproc PyPostproc(
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyPostprocNode::FApply f_apply, //
PyPostprocNode::FAsString f_as_string);
TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
/*!
* \brief Create a postprocessor that checks if all loops are static
* \return The postprocessor created
Expand Down Expand Up @@ -164,6 +153,37 @@ class Postproc : public runtime::ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

/*! \brief The postprocessor with customized methods on the python-side. */
class PyPostprocNode : public PostprocNode {
public:
using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext;
using FApply = Postproc::FApply;
using FClone = Postproc::FClone;
using FAsString = Postproc::FAsString;
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_clone` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
bool Apply(const tir::Schedule& sch) final;
Postproc Clone() const final;

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
};

} // namespace meta_schedule
} // namespace tvm

Expand Down
86 changes: 54 additions & 32 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace tvm {
namespace meta_schedule {

class TuneContext;
class ScheduleRule;

/*! \brief Rules to modify a block in a schedule. */
class ScheduleRuleNode : public runtime::Object {
Expand All @@ -59,12 +60,21 @@ class ScheduleRuleNode : public runtime::Object {
virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch,
const tir::BlockRV& block) = 0;

/*!
* \brief Deep clone the schedule rule.
* \return The cloned schedule rule.
*/
virtual ScheduleRule Clone() const = 0;

static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object);
};

/*! \brief The schedule rule with customized methods on the python-side. */
class PyScheduleRuleNode : public ScheduleRuleNode {
/*!
* \brief Managed reference to ScheduleRuleNode
* \sa ScheduleRuleNode
*/
class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
Expand All @@ -84,33 +94,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
* \return The string of the schedule rule.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;

static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
};

/*!
* \brief Managed reference to ScheduleRuleNode
* \sa ScheduleRuleNode
*/
class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `Clone` method.
* \return The cloned schedule rule.
*/
using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
/*!
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
Expand Down Expand Up @@ -249,16 +237,50 @@ class ScheduleRule : public runtime::ObjectRef {
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The schedule rule created.
*/
TVM_DLL static ScheduleRule PyScheduleRule(
PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyScheduleRuleNode::FApply f_apply, //
PyScheduleRuleNode::FAsString f_as_string);
FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};

/*! \brief The schedule rule with customized methods on the python-side. */
class PyScheduleRuleNode : public ScheduleRuleNode {
public:
using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext;
using FApply = ScheduleRule::FApply;
using FClone = ScheduleRule::FClone;
using FAsString = ScheduleRule::FAsString;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
// `f_clone` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
ScheduleRule Clone() const final;

static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
};

} // namespace meta_schedule
} // namespace tvm

Expand Down
Loading