9 changes: 2 additions & 7 deletions src/tir/schedule/analysis.h
@@ -168,16 +168,11 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
/*!
* \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees
* rooted at its direct children, and this property requires all the blocks of the subtree
* that the specified sref is in to be complete block or reduction block.
* that the specified sref is in to be local complete block or local reduction block.
* \param self The schedule state
* \param subtree_root The sref of the subtree root to be checked
* \param scope_root_sref The scope root of the block
* \throw ScheduleError If the subtree that the sref is in doesn't satisfy the compact
* dataflow condition, i.e. a block in the subtree is neither complete block nor
* reduction block
*/
void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
const StmtSRef& scope_root_sref);
void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root);
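With the scope-root argument dropped, the compact-dataflow check is evaluated against the subtree root itself, so the blocks under that subtree only need to be local complete or local reduction blocks there. A minimal sketch of the user-facing effect, assuming TVM's Python `tir.Schedule` API and the TVMScript syntax used by the tests below; the prim_func `split_halves` is hypothetical and written only for illustration:

from tvm import tir
from tvm.script import tir as T

@T.prim_func
def split_halves(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]) -> None:
    # Two blocks write disjoint halves of B: neither is the sole writer of B under
    # the root scope, but each is the sole writer within its own loop subtree.
    for i in T.serial(8):
        with T.block("lower"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi] + 1.0
    for i in T.serial(8):
        with T.block("upper"):
            vi = T.axis.spatial(16, i + 8)
            B[vi] = A[vi] * 2.0

s = tir.Schedule(split_halves)
(loop,) = s.get_loops(s.get_block("lower"))
s.parallel(loop)  # accepted: only the subtree rooted at `loop` is checked now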
/*!
* \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is
* not allocated under the current scope
109 changes: 81 additions & 28 deletions src/tir/schedule/analysis/analysis.cc
@@ -143,25 +143,53 @@ ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) {
return std::move(visitor.result);
}

/*!
* \brief Check whether the given sref_a is higher than or equal to sref_b.
*/
void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) {
const StmtSRefNode* p = sref_b.get();
for (; p != nullptr; p = p->parent) {
if (p == sref_a.get()) {
return;
}
}
CHECK(false) << "Expect StmtSRef " << sref_a << " to be higher than or equal to " << sref_b;
}

/*!
* \brief Check the dominant property of a block:
* the block is the only writer of its output, dominating the reader of its output buffers
* \param scope The block-scope of the block to be checked
* \param block_sref The block whose dominant property is to be checked
* \return A boolean indicating if the block is a dominant block
* the block is the only writer of its output, dominating the reader of its output buffers under the
* given root scope.
* \param self The schedule state.
* \param scope_root_sref The StmtSRef corresponding to the root scope.
* \param block_sref The block whose dominant property is to be checked.
* \return A boolean indicating if the block is a dominant block.
*/
bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
const StmtSRef& block_sref) {
std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
CheckSRefHigherOrEqual(scope_root_sref, block_sref);
const BlockNode* maybe_root_block = scope_root_sref->StmtAs<BlockNode>();
if (maybe_root_block) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
buffer_writers = scope->buffer_writers;
} else {
// Collect all child blocks of the root subtree and merge their buffer writers.
Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref);
for (const StmtSRef& child_block_sref : child_block_srefs) {
BlockScope child_scope = self->GetBlockScope(child_block_sref);
for (const auto& it : child_scope->buffer_writers) {
buffer_writers.insert(it);
}
}
}
// Check whether the input block is the only writer of its outputs
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>& buffer_writers =
scope->buffer_writers;
for (const BufferRegion& write_region : block->writes) {
ICHECK(buffer_writers.count(write_region->buffer))
<< "InternalError: buffer \"" << write_region->buffer->name
<< "\" does not exist in the current scope, when querying block:\n"
<< GetRef<Block>(block);
if (buffer_writers.at(write_region->buffer).size() != 1) {
return false;
if (buffer_writers.count(write_region->buffer)) {
if (buffer_writers.at(write_region->buffer).size() != 1) {
return false;
}
}
}
return true;
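For intuition, the dominance property boils down to counting the writers of each buffer in the relevant scope. A purely illustrative sketch in plain Python, not the TVM data structures or the exact C++ control flow; the block names and writer map below are hypothetical:

# Hypothetical writer map for a decomposed GEMM: "init" and "update" both write
# the intermediate buffer `local`, while only "C" writes the output buffer `C`.
buffer_writers = {"local": ["init", "update"], "C": ["C"]}

def is_dominant(block, writes, writer_map):
    # A block is dominant iff it is the only recorded writer of every buffer it
    # writes; buffers missing from the map are skipped, as in the code above.
    return all(len(writer_map.get(buf, [block])) == 1 for buf in writes)

print(is_dominant("init", ["local"], buffer_writers))  # False: "update" also writes `local`
print(is_dominant("C", ["C"], buffer_writers))         # True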
@@ -178,7 +206,6 @@ bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
*/
int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
// Cond 1. All block vars are data parallel
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
for (const IterVar& iter_var : block->iter_vars) {
@@ -188,7 +215,7 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block
}
// Cond 2. Dominant: the block is the only writer of its output,
// dominating the reader of its output buffers
if (!IsDominantBlock(scope, block_sref)) {
if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
return 2;
}
// Cond 3. No overlap between the buffers the block reads and writes
@@ -217,6 +244,18 @@ static const char* kReductionBlockDefinition = R"(Definition of a reduction bloc
4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers
5) The reduction block vars are not used to index the output buffers)";

static const char* kLocalCompleteBlockDefinition = R"(Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes)";

static const char* kLocalReductionBlockDefinition = R"(Definition of a local reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers)";
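As a concrete reading of these definitions (mirroring the decomposed GEMM test added below): under the root scope the "init" block is not a complete block, because "update" also writes `local`; under the subtree rooted at its innermost `jj` loop it is the sole writer, so it is a local complete block and vectorizing that loop becomes legal. A minimal sketch, assuming the `decomposed_gemm` workload defined in tests/python/unittest/test_tir_schedule_for_kind.py and `from tvm import tir`:

s = tir.Schedule(decomposed_gemm, debug_mask="all")
i, j, ii, jj = s.get_loops(s.get_block("init"))
s.vectorize(jj)  # "init" is a local complete block under the `jj` subtree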

bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0;
@@ -260,7 +299,6 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
*/
int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
// Cond 1. The block has the `init` statement.
if (!block->init.defined()) {
@@ -277,7 +315,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc
}
// Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
// output buffers.
if (!IsDominantBlock(scope, block_sref)) {
if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
return 4;
}
// Cond 5. The reduction block vars are not used to index the output buffers.
@@ -363,40 +401,55 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
reduction_block_error_code);
}

void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
const StmtSRef& scope_root_sref) {
void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) {
class NotCompactDataFlowError : public ScheduleError {
public:
explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block)
explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block,
int local_complete_block_code, int local_reduction_block_code)
: mod_(std::move(mod)),
subtree_root_(std::move(subtree_root)),
violate_block_(std::move(violate_block)) {
violate_block_(std::move(violate_block)),
local_complete_block_code_(local_complete_block_code),
local_reduction_block_code_(local_reduction_block_code) {
ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>());
}
String FastErrorString() const final {
return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
"because some of its child block on SRef tree is neither a complete block nor a "
"reduction block";
"because some of its child block on SRef tree is neither a local complete block nor a "
"local reduction block.";
}
String DetailRenderTemplate() const final {
return "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
"its child block {1} on SRef tree is neither a complete block nor a reduction block";
std::ostringstream os;
os << "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
"its child block {1} on SRef tree is neither a local complete block nor a local "
"reduction block.\n";
os << "It violates condition #" << local_complete_block_code_
<< " as a local complete block.\n";
os << kLocalCompleteBlockDefinition << "\n";
os << "It violates condition #" << local_reduction_block_code_
<< " as a local reduction block.\n";
os << kLocalReductionBlockDefinition << "\n";
return os.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }

IRModule mod_;
Stmt subtree_root_;
Block violate_block_;
int local_complete_block_code_;
int local_reduction_block_code_;
};

Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root);
for (const StmtSRef& block_sref : child_block_srefs) {
if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
!IsReductionBlock(self, block_sref, scope_root_sref)) {
int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root),
local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root);
if (local_complete_block_code != 0 && local_reduction_block_code != 0) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
GetRef<Block>(block));
GetRef<Block>(block), local_complete_block_code,
local_reduction_block_code);
}
}
}
4 changes: 1 addition & 3 deletions src/tir/schedule/primitive/for_kind.cc
@@ -157,9 +157,7 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
* parallelized/vectorized/bound.
*/
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref,
/*require_stage_pipeline=*/true);
CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref);
CheckSubtreeCompactDataflow(self, loop_sref);

// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
// underlying block.
89 changes: 89 additions & 0 deletions tests/python/unittest/test_tir_schedule_for_kind.py
@@ -330,6 +330,72 @@ def decomposed_gemm_after_vectorize(
C[vi, vj] = local[vi, vj]


@T.prim_func
def decomposed_gemm_parallelize_init(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"],
C: T.Buffer[(16, 16), "float32"],
) -> None:
local = T.alloc_buffer([16, 16], dtype="float32")
for i, j in T.grid(4, 4):
for ii in T.serial(4):
for jj in T.vectorized(4):
with T.block("init"):
vi = T.axis.spatial(16, i * 4 + ii)
vj = T.axis.spatial(16, j * 4 + jj)
T.reads()
T.writes(local[vi, vj])
local[vi, vj] = 0
for k, ii, jj in T.grid(16, 4, 4):
with T.block("update"):
vi = T.axis.spatial(16, i * 4 + ii)
vj = T.axis.spatial(16, j * 4 + jj)
vk = T.axis.reduce(16, k)
T.reads(local[vi, vj], A[vi, vk], B[vj, vk])
T.writes(local[vi, vj])
local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk]
for ii, jj in T.grid(4, 4):
with T.block("C"):
vi = T.axis.spatial(16, i * 4 + ii)
vj = T.axis.spatial(16, j * 4 + jj)
T.reads(local[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = local[vi, vj]


@T.prim_func
def scatter_compute(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]):
for i in T.grid(8):
with T.block("first_half"):
vi = T.axis.spatial(16, 8 + i)
B[vi] = A[vi - 8]

for i in T.grid(8):
with T.block("last_half"):
vi = T.axis.spatial(16, i)
B[vi] = A[vi + 8]


@T.prim_func
def scatter_compute_parallelize(
A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]
) -> None:
# body
# with T.block("root")
for i in T.parallel(8):
with T.block("first_half"):
vi = T.axis.spatial(16, 8 + i)
T.reads(A[vi - 8])
T.writes(B[vi])
B[vi] = A[vi - 8]
for i in T.parallel(8):
with T.block("last_half"):
vi = T.axis.spatial(16, i)
T.reads(A[vi + 8])
T.writes(B[vi])
B[vi] = A[vi + 8]


# pylint: enable=no-member,invalid-name,unused-variable


@@ -468,5 +534,28 @@ def test_vectorize_after_decompose():
verify_trace_roundtrip(s, mod=decomposed_gemm)


def test_vectorize_init():
s = tir.Schedule(decomposed_gemm, debug_mask="all")
init_blk = s.get_block("init")
upd_blk = s.get_block("update")
_, _, ii_0, jj_0 = s.get_loops(init_blk)
_, _, k_1, ii_1, jj_1 = s.get_loops(upd_blk)
s.vectorize(jj_0)
tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_parallelize_init)
verify_trace_roundtrip(s, mod=decomposed_gemm)


def test_scatter_parallelize():
s = tir.Schedule(scatter_compute, debug_mask="all")
first = s.get_block("first_half")
last = s.get_block("last_half")
(i_0,) = s.get_loops(first)
(i_1,) = s.get_loops(last)
s.parallel(i_0)
s.parallel(i_1)
tvm.ir.assert_structural_equal(s.mod["main"], scatter_compute_parallelize)
verify_trace_roundtrip(s, mod=scatter_compute)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))