Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,14 @@ class SpaceGenerator : public runtime::ObjectRef {
TVM_DLL static SpaceGenerator PySpaceGenerator(
PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space);

/*!
* \brief Create a design space generator with customized schedule function.
* \param schedule_fn The schedule function, which can have the following signatures:
* 1) void(Schedule)
* 2) Schedule(Schedule)
* 3) Array<Schedule>(Schedule)
*/
TVM_DLL static SpaceGenerator ScheduleFn(PackedFunc schedule_fn);
/*!
* \brief Create a design space generator that is union of multiple design space generators.
* \param space_generators An array of design space generators to be unioned.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/space_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
space for generation of measure candidates.
"""
from .post_order_apply import PostOrderApply
from .schedule_fn import SCH_FN_TYPE, ScheduleFn
from .space_generator import PySpaceGenerator, SpaceGenerator
from .schedule_fn import ScheduleFn
from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator
from .space_generator_union import SpaceGeneratorUnion
88 changes: 22 additions & 66 deletions python/tvm/meta_schedule/space_generator/schedule_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,78 +14,34 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Meta schedule design space generators that generates design
space via a schedule function.
"""
from typing import TYPE_CHECKING, Callable, List, Union
"""Union of meta Schedule design space generators."""
from tvm._ffi import register_object

from tvm.ir import IRModule
from tvm.ir.container import Array
from tvm.meta_schedule.utils import derived_object
from tvm.tir.schedule import Schedule
from .. import _ffi_api
from .space_generator import SpaceGenerator

from .space_generator import PySpaceGenerator

if TYPE_CHECKING:
from ..tune_context import TuneContext
@register_object("meta_schedule.ScheduleFn")
class ScheduleFn(SpaceGenerator):
"""Create a design space generator with customized schedule function.
The schedule function can have the following signatures:
- 1) [Schedule] -> None
- 2) [Schedule] -> Schedule
- 3) [Schedule] -> List[Schedule]
"""

SCH_FN_TYPE = Union[ # pylint: disable=invalid-name
Callable[[Schedule], None], # No output
Callable[[Schedule], Schedule], # Single output
Callable[[Schedule], List[Schedule]], # Multiple outputs
]


@derived_object
class ScheduleFn(PySpaceGenerator):
"""A design space generator with design spaces specified by a schedule function."""

def __init__(self, sch_fn: SCH_FN_TYPE):
def __init__(self, sch_fn: SpaceGenerator.ScheduleFnType):
"""Constructor.

Parameters
----------
sch_fn : SCH_FN_TYPE
The schedule function.
"""
super().__init__()
self.sch_fn = sch_fn

def _initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the design space generator with tuning context.

Parameters
----------
context : TuneContext
The tuning context for initializing the design space generator.
"""

def generate_design_space(self, mod: IRModule) -> List[Schedule]:
"""Generate design spaces given a module.

Parameters
----------
mod : IRModule
The module used for design space generation.

Returns
-------
design_spaces : List[Schedule]
The generated design spaces, i.e., schedules.
sch_fn : SpaceGenerator.ScheduleFnType
The schedule function, which can have the following signatures:
- 1) [Schedule] -> None
- 2) [Schedule] -> Schedule
- 3) [Schedule] -> List[Schedule]
"""
sch = Schedule(mod) # Make sure the schedule is traced
result = self.sch_fn(sch) # Call the schedule function
if result is None: # Case 1. No output
return [sch]
if isinstance(result, Schedule): # Case 2. Single output
return [result]
if isinstance(result, (list, tuple, Array)): # Case 3. Multiple outputs
for ret in result: # enumerate the outputs
if not isinstance(ret, Schedule):
raise TypeError(
"Wrong type of element in the list, expected Schedule got "
+ f"'{type(ret)}': {ret}"
)
return result
raise TypeError(f"Unexpected return type {type(result)}: {result}")
self.__init_handle_by_constructor__(
_ffi_api.SpaceGeneratorScheduleFn, # type: ignore # pylint: disable=no-member
sch_fn,
)
11 changes: 10 additions & 1 deletion python/tvm/meta_schedule/space_generator/space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Meta Schedule design space generators that generates design
space for generation of measure candidates.
"""
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Union

from tvm._ffi import register_object
from tvm.ir import IRModule
Expand All @@ -35,6 +35,12 @@
class SpaceGenerator(Object):
"""The abstract design space generator interface."""

ScheduleFnType = Union[
Callable[[Schedule], None], # No output
Callable[[Schedule], Schedule], # Single output
Callable[[Schedule], List[Schedule]], # Multiple outputs
]

def _initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the design space generator with tuning context.

Expand Down Expand Up @@ -63,6 +69,9 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]:
return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member


ScheduleFnType = SpaceGenerator.ScheduleFnType


@register_object("meta_schedule.PySpaceGenerator")
class _PySpaceGenerator(SpaceGenerator):
"""
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .runner import RunnerResult
from .schedule_rule import ScheduleRule
from .search_strategy import MeasureCandidate, SearchStrategy
from .space_generator import SCH_FN_TYPE, ScheduleFn, SpaceGenerator
from .space_generator import ScheduleFn, ScheduleFnType, SpaceGenerator
from .tune import TuneConfig


Expand All @@ -55,7 +55,7 @@ class TuneContext(Object):
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
space_generator : Union[None, SCH_FN_TYPE, SpaceGenerator] = None
space_generator : Union[None, ScheduleFnType, SpaceGenerator] = None
The design space generator.
search_strategy : Union[None, TuneConfig, SearchStrategy] = None
The search strategy.
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(
mod: Optional[IRModule] = None,
*,
target: Optional[Target] = None,
space_generator: Union[None, "SCH_FN_TYPE", "ScheduleFn", "SpaceGenerator"] = None,
space_generator: Union[None, "ScheduleFnType", "ScheduleFn", "SpaceGenerator"] = None,
search_strategy: Union[None, "SearchStrategy", "TuneConfig"] = None,
sch_rules: Union[None, str, List["ScheduleRule"]] = None,
postprocs: Union[None, str, List["Postproc"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
this->logging_func = context->logging_func;
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod_) final {
Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>;
tir::Schedule sch = tir::Schedule::Traced(
/*mod=*/mod_,
/*mod=*/mod,
/*rand_state=*/ForkSeed(&this->rand_state_),
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
Expand Down
90 changes: 90 additions & 0 deletions src/meta_schedule/space_generator/schedule_fn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief The union of design space generators. */
class ScheduleFnNode : public SpaceGeneratorNode {
public:
/*! \brief The random state. -1 means using random number. */
TRandState rand_state_ = -1;
/*! \brief The schedule function. */
runtime::PackedFunc schedule_fn_;

void VisitAttrs(tvm::AttrVisitor* v) {
// `schedule_fn_` is not visited.
}

void InitializeWithTuneContext(const TuneContext& context) final {
this->rand_state_ = ForkSeed(&context->rand_state);
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
tir::Schedule sch = tir::Schedule::Traced(
/*mod=*/mod,
/*rand_state=*/ForkSeed(&this->rand_state_),
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
runtime::TVMRetValue rv;
rv = this->schedule_fn_(sch);
if (rv.type_code() == kTVMNullptr) {
return {sch};
}
ObjectRef obj = rv;
if (const auto* sch = obj.as<tir::ScheduleNode>()) {
return {GetRef<tir::Schedule>(sch)};
}
if (const auto* arr = obj.as<runtime::ArrayNode>()) {
Array<tir::Schedule> result;
result.reserve(arr->size());
for (const ObjectRef& obj : *arr) {
if (const auto* sch = obj.as<tir::ScheduleNode>()) {
result.push_back(GetRef<tir::Schedule>(sch));
} else {
LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or "
"List[Schedule], but got: "
<< obj->GetTypeKey();
}
}
return result;
}
LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or "
"List[Schedule], but got: "
<< obj->GetTypeKey();
throw;
}

static constexpr const char* _type_key = "meta_schedule.ScheduleFn";
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode);
};

SpaceGenerator SpaceGenerator::ScheduleFn(PackedFunc schedule_fn) {
ObjectPtr<ScheduleFnNode> n = make_object<ScheduleFnNode>();
n->schedule_fn_ = std::move(schedule_fn);
return SpaceGenerator(n);
}

TVM_REGISTER_NODE_TYPE(ScheduleFnNode);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn")
.set_body_typed(SpaceGenerator::ScheduleFn);

} // namespace meta_schedule
} // namespace tvm