Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
223f6ff
Introduce new module equality to extract only anchor block tasks
masahi Oct 12, 2022
f5ca225
enabling application of anchor trace to different subgraph
masahi Oct 18, 2022
e6a4f21
fixed anchor block extraction
masahi Oct 19, 2022
f107fd7
fixed UB in task extraction
masahi Oct 20, 2022
b2dfc18
Reworked anchor trace application and inlining logic
masahi Oct 20, 2022
a4bacd1
fixed anchor block extraction for winograd
masahi Oct 21, 2022
718a514
fix inline logic for winograd
masahi Oct 21, 2022
6bbd4d8
refactor, clean up, renaming
masahi Oct 21, 2022
fbe5160
fix reverse compute inline unapplicable case
masahi Oct 24, 2022
cf4d8b7
fixed get_block applicablity condition
masahi Oct 24, 2022
8ee4da6
adding test
masahi Oct 24, 2022
51b3766
introduce HasBlock utility
masahi Oct 25, 2022
289cd9a
Decoupled trace creation and application in Trace::ApplyJSONToschedule
masahi Oct 25, 2022
5c0d47f
add test
masahi Oct 25, 2022
fbb2361
adding more test
masahi Oct 25, 2022
0b48c14
black
masahi Oct 25, 2022
ac24ea3
Revert "Decoupled trace creation and application in Trace::ApplyJSONT…
masahi Oct 25, 2022
200a2a5
add tests
masahi Oct 25, 2022
0423cec
add doc
masahi Oct 26, 2022
abb2d0b
use anchor tuning in hexagon int8 tuning test
masahi Oct 26, 2022
1e7db84
cpplint
masahi Oct 26, 2022
c84bf21
suppress mypy on ffi
masahi Oct 26, 2022
346d55f
add workaround for false positive maybe-uninitialized warning
masahi Oct 26, 2022
b88e63b
add a minimal anchor tuning test
masahi Oct 26, 2022
2f37900
relax tol for i386, remove gpu test since it requires sm86
masahi Oct 27, 2022
5e2367a
add doc for "anchor-block" module equality
masahi Oct 27, 2022
d6893af
address comments
masahi Oct 28, 2022
1fe554d
add test for cache_write + AllocateConst bug
masahi Oct 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class DatabaseNode : public runtime::Object {
* - "structural": Use StructuralEqual/Hash
* - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
* equality testing and hashing.
* - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
* given module. The "ignore-ndarray" variant is used for the extracted blocks
* or in case no anchor block is found.
* For the definition of the anchor block, see tvm/tir/analysis.h.
*/
explicit DatabaseNode(String mod_eq_name = "structural");

Expand Down Expand Up @@ -274,6 +278,10 @@ class PyDatabaseNode : public DatabaseNode {
* - "structural": Use StructuralEqual/Hash
* - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
* equality testing and hashing.
* - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
* given module. The "ignore-ndarray" variant is used for the extracted blocks
* or in case no anchor block is found.
* For the definition of the anchor block, see tvm/tir/analysis.h.
*/
explicit PyDatabaseNode(String mod_eq_name = "structural");

Expand Down
5 changes: 4 additions & 1 deletion include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,12 @@ class ScheduleRule : public runtime::ObjectRef {
* \brief Auto bind loops around the block to BlockIdx and ThreadIdx
* \param max_threadblocks The maximum number of threadblock on GPU
* \param thread_extents Candidates of thread axis extent.
* \param max_threads_per_block The maximum number of threads per block, if it is known
* when this schedule rule is created.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents);
TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents,
int max_threads_per_block = -1);
/*!
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
search_strategy,
space_generator,
tir_integration,
trace_apply,
)
from .builder import Builder
from .cost_model import CostModel
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/meta_schedule/database/json_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class JSONDatabase(Database):
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
- "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
given module. The "ignore-ndarray" variant is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.
"""

path_workload: str
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class MemoryDatabase(Database):
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
- "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
given module. The "ignore-ndarray" variant is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.
"""

def __init__(
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/meta_schedule/database/schedule_fn_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class ScheduleFnDatabase(Database):
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
- "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
given module. The "ignore-ndarray" variant is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.
"""

def __init__(
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def extract_tasks(
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
- "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
given module. The "ignore-ndarray" variant is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.

Returns
-------
Expand Down Expand Up @@ -288,6 +292,10 @@ def tune_relay(
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
- "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
given module. The "ignore-ndarray" variant is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.

Returns
-------
Expand Down
39 changes: 39 additions & 0 deletions python/tvm/meta_schedule/trace_apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Specialized applications of trace"""
from ..tir.schedule import Schedule, Trace
from ..target import Target
from . import _ffi_api


def schedule_using_anchor_trace(sch: Schedule, anchor_trace: Trace, target: Target) -> None:
    """Apply the trace from a TIR module whose anchor block is the same but fused elementwise op
    blocks differ. This function can be used for transferring a trace tuned on a conv2d -> add
    subgraph to other subgraphs having the same conv2d workload, for example. We call such a
    trace an "anchor trace". Those blocks that are not scheduled by the given anchor trace will
    be either inlined or parallelized.

    Parameters
    ----------
    sch : Schedule
        The target schedule, mutated in place (nothing is returned).
    anchor_trace : Trace
        The trace generated for another TIR module having the same anchor block
    target : tvm.target.Target
        The compilation target
    """
    # Dispatch to the C++ implementation via the FFI registry.
    _ffi_api.ScheduleUsingAnchorTrace(sch, anchor_trace, target)  # type: ignore
4 changes: 4 additions & 0 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def tune_tasks(
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
- "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
given module. The "ignore-ndarray" variant is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.

Returns
-------
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/script/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
# add all floating point and integer datatypes to the module
for _dtype in ["float", "uint", "int"]:
for _size in ["8", "16", "32", "64"]:
for _lanes in ["", "x4", "x8", "x16", "x32"]:
for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]:
from . import ty

_name = _dtype + _size + _lanes
globals()[_name] = getattr(ty, _name)
if hasattr(ty, _name):
globals()[_name] = getattr(ty, _name)
2 changes: 1 addition & 1 deletion python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __getitem__(self, args):
# add all floating point and integer datatypes to the module
for _dtype in ["float", "uint", "int"]:
for _size in ["8", "16", "32", "64"]:
for _lanes in ["", "x4", "x8", "x16", "x32"]:
for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]:
_name = _dtype + _size + _lanes
globals()[_name] = ConcreteType(_name)

Expand Down
18 changes: 18 additions & 0 deletions python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,21 @@ def get_auto_tensorize_mapping_info(
intrinsics.
"""
return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore


def has_block(sch: Schedule, block_name: str) -> bool:
    """Query if the given block name exists in the module associated with the provided schedule.

    Unlike ``Schedule.get_block``, this does not raise when the block is absent;
    it simply reports existence.

    Parameters
    ----------
    sch : Schedule
        The schedule
    block_name : str
        The name of the block to query

    Returns
    -------
    yes/no: bool
        True if the given block exists in the schedule.
    """
    # Dispatch to the C++ implementation via the FFI registry.
    return _ffi_api.HasBlock(sch, block_name)  # type: ignore
24 changes: 24 additions & 0 deletions src/meta_schedule/module_equality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/tir/analysis.h>

#include <memory>

Expand Down Expand Up @@ -73,11 +74,34 @@ class ModuleEqualityIgnoreNDArray : public ModuleEquality {
}
};

// The NDArray-ignoring variant of structural equal / hash is used for the module equality
// on the extracted anchor blocks.
class ModuleEqualityAnchorBlock : public ModuleEquality {
size_t Hash(IRModule mod) const {
auto anchor_block = tir::FindAnchorBlock(mod);
if (anchor_block) {
return SHashHandlerIgnoreNDArray().Hash(GetRef<tir::Block>(anchor_block), false);
}
return ModuleEqualityIgnoreNDArray().Hash(mod);
}
bool Equal(IRModule lhs, IRModule rhs) const {
auto anchor_block_lhs = tir::FindAnchorBlock(lhs);
auto anchor_block_rhs = tir::FindAnchorBlock(rhs);
if (anchor_block_lhs && anchor_block_rhs) {
return SEqualHandlerIgnoreNDArray().Equal(GetRef<tir::Block>(anchor_block_lhs),
GetRef<tir::Block>(anchor_block_rhs), false);
}
return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs);
}
};

std::unique_ptr<ModuleEquality> ModuleEquality::Create(const std::string& mod_eq_name) {
if (mod_eq_name == "structural") {
return std::make_unique<ModuleEqualityStructural>();
} else if (mod_eq_name == "ignore-ndarray") {
return std::make_unique<ModuleEqualityIgnoreNDArray>();
} else if (mod_eq_name == "anchor-block") {
return std::make_unique<ModuleEqualityAnchorBlock>();
}
LOG(FATAL) << "Unknown module equality " << mod_eq_name;
return nullptr;
Expand Down
4 changes: 4 additions & 0 deletions src/meta_schedule/module_equality.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class ModuleEquality {
* - "structural": Use StructuralEqual/Hash
* - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
* equality testing and hashing.
* - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
* given module. The "ignore-ndarray" variant is used for the extracted blocks
* or in case no anchor block is found.
* For the definition of the anchor block, see tvm/tir/analysis.h.
* \return An owning pointer to the created instance
*/
static std::unique_ptr<ModuleEquality> Create(const std::string& mod_eq_name);
Expand Down
5 changes: 3 additions & 2 deletions src/meta_schedule/schedule_rule/auto_bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,11 @@ Array<tir::Schedule> AutoBindNode::Apply(const tir::Schedule& sch, const tir::Bl
return {sch};
}

ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<Integer> thread_extents) {
ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<Integer> thread_extents,
int max_threads_per_block) {
ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>();
n->max_threadblocks_ = max_threadblocks;
n->max_threads_per_block_ = -1;
n->max_threads_per_block_ = max_threads_per_block;
n->thread_extents_ = std::move(thread_extents);
return ScheduleRule(n);
}
Expand Down
18 changes: 2 additions & 16 deletions src/meta_schedule/schedule_rule/schedule_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,7 @@ ScheduleRule ScheduleRule::PyScheduleRule(

Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
return {
ScheduleRule::AutoInline(
/*into_producer=*/false,
/*into_consumer=*/true,
/*inline_const_tensor=*/true,
/*disallow_if_then_else=*/true,
/*require_injective=*/true,
/*require_ordered=*/true,
/*disallow_op=*/Array<String>{"tir.exp"}),
GetDefaultAutoInline("llvm"),
ScheduleRule::AddRFactor(
/*max_jobs_per_core=*/16,
/*max_innermost_factor=*/Integer(64)),
Expand Down Expand Up @@ -98,14 +91,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
Map<String, ObjectRef>{{"req", String("must")},
{"levels", Array<Integer>{3}}, //
{"scope", String("local")}}),
ScheduleRule::AutoInline(
/*into_producer=*/true,
/*into_consumer=*/true,
/*inline_const_tensor=*/true,
/*disallow_if_then_else=*/false,
/*require_injective=*/false,
/*require_ordered=*/false,
/*disallow_op=*/Array<String>{}),
GetDefaultAutoInline("cuda"),
ScheduleRule::CrossThreadReduction(
/*thread_extents=*/Array<Integer>{4, 8, 16, 32, 64, 128, 256, 512}),
ScheduleRule::ParallelizeVectorizeUnroll(
Expand Down
7 changes: 3 additions & 4 deletions src/meta_schedule/space_generator/space_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ String GetRuleKindFromTarget(const Target& target) {
}
return "cuda";
}
if (target->kind->name == "rocm") {
return "cuda";
}
if (target->kind->name == "vulkan") {

if (IsGPUTarget(target->kind->name)) {
return "cuda";
}

LOG(FATAL) << "Unsupported target: " << target;
throw;
}
Expand Down
Loading