81 changes: 70 additions & 11 deletions src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -144,13 +144,40 @@ void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedA
}
}

int CalculateNumRewritableLoops(const Array<StmtSRef>& loop_srefs,
const std::vector<int>& loop_types) {
int rw_loops_num = 0;
ICHECK_EQ(loop_srefs.size(), loop_types.size());
for (size_t i = 0; i < loop_srefs.size(); ++i) {
const StmtSRef& loop_sref = loop_srefs[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
if (HasAnnOrBinding(loop)) {
continue;
}
// Cannot vectorize reduce axis
if (loop_types[i] != IterVarType::kDataPar) {
continue;
}
// Cannot fuse with a loop with multiple children
if (!IsSingleStmt(loop->body)) {
continue;
}
// Check if the loop extent is valid
if (GetLoopIntExtent(loop_sref) == nullptr) {
continue;
}
++rw_loops_num;
}
return rw_loops_num;
}

void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
const Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) {
StmtSRef block_sref = sch->GetSRef(block_rv);
if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) {
return;
}
- int n_loops = loop_rvs.size();
const int n_loops = loop_rvs.size();
if (n_loops == 0) {
parsed->max_parallel_extent = -1;
parsed->max_vectorize_extent = -1;
@@ -226,6 +253,10 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
}
max_fusible = std::min(max_fusible, fusible);
}

// Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization.
int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types);

// Calculate the parallelize extent
if (parsed->max_parallel_extent != -1) {
int max_extent = parsed->max_parallel_extent;
Expand Down Expand Up @@ -290,10 +321,17 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
num_fusible = -1;
}
}
- // Prefer num_vectorize to num_parallel

if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) {
- parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops);
if (max_rw_loops == n_loops && max_fusible == n_loops) {
// All loops can be fused, parallelized and vectorized
parsed->num_parallel_loops = n_loops;
parsed->num_vectorize_loops = n_loops;
} else {
// Prefer num_vectorize to num_parallel
parsed->num_parallel_loops =
std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops);
}
}
}

@@ -317,6 +355,21 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block
return false;
}

void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array<LoopRV>* loop_rvs, int vec_len) {
size_t n_loops = loop_rvs->size();
LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()});
Array<LoopRV> split = sch->Split(fused, {NullOpt, Integer(vec_len)});
ICHECK_EQ(split.size(), 2);
const LoopRV& outer = split[0];
const LoopRV& inner = split[1];
sch->Parallel(outer);
sch->Vectorize(inner);
for (size_t i = 0; i < n_loops - 1; ++i) {
loop_rvs->Set(i, outer);
}
loop_rvs->Set(n_loops - 1, inner);
}

void RewriteParallel(const Schedule& sch, size_t n, Array<LoopRV>* loop_rvs) {
ICHECK_LE(n, loop_rvs->size());
LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n});
@@ -364,13 +417,19 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode {
}
tir::ParsedAnnotation parsed = parsed_root;
tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed);
- // Parallel
- if (parsed.num_parallel_loops > 0) {
- tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs);
- }
- // Vectorize
- if (parsed.num_vectorize_loops > 0) {
- tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs);
const int loops_num = loop_rvs.size();
if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) {
// Fuse, split, vectorize and parallelize
tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent);
} else {
// Parallel
if (parsed.num_parallel_loops > 0) {
tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs);
}
// Vectorize
if (parsed.num_vectorize_loops > 0) {
tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs);
}
}
// AutoUnroll
if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) {
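For reference, here is a minimal Python sketch of the rewrite that the new RewriteFuseSplitParallelVectorize helper performs, driven by hand through the public tir.Schedule API. The workload mirrors before_postproc_add from the test diff below (the function, block, and buffer names in the sketch are mine), and the vector length 128 matches its "meta_schedule.vectorize" annotation; the postproc only takes this path when every loop is rewritable and fusible, i.e. max_rw_loops == n_loops && max_fusible == n_loops.

# Sketch only: reproduces the fuse -> split -> parallel/vectorize rewrite by hand.
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def add(
    lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
    rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
    out: T.Buffer((1, 8, 56, 56, 32), "uint8"),
) -> None:
    for n, c0, h, w, c1 in T.grid(1, 8, 56, 56, 32):
        with T.block("add"):
            v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [n, c0, h, w, c1])
            out[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]


sch = tir.Schedule(add)
loops = sch.get_loops(sch.get_block("add"))           # the five data-parallel loops
fused = sch.fuse(*loops)                              # single loop of extent 1*8*56*56*32
outer, inner = sch.split(fused, factors=[None, 128])  # inner extent = vectorize length
sch.parallel(outer)                                   # mirrors sch->Parallel(outer) above
sch.vectorize(inner)                                  # mirrors sch->Vectorize(inner) above
print(sch.mod.script())

After these four scheduling calls the printed module has a parallel outer loop of extent 6272 over a vectorized inner loop of extent 128, the same structure that after_postproc_add asserts.
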
@@ -142,6 +142,42 @@ def after_matmul_vectorize(
T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]


@T.prim_func
def before_postproc_add(
lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"),
) -> None:
with T.block("root"):
T.block_attr({"meta_schedule.parallel":64, "meta_schedule.vectorize":128})
for n, c0, h, w, c1 in T.grid(1, 8, 56, 56, 32):
with T.block("add_compute"):
v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [n, c0, h, w, c1])
T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4])
T.writes(add_compute[v0, v1, v2, v3, v4])
add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]


@T.prim_func
def after_postproc_add(
lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"),
) -> None:
with T.block("root"):
for n_c0_h_w_c1_fused_0 in T.parallel(0, 6272):
for n_c0_h_w_c1_fused_1 in T.vectorized(0, 128):
with T.block("add_compute"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(8, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) // 100352)
v2 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 100352 // 1792)
v3 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 1792 // 32)
v4 = T.axis.spatial(32, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 32)
T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4])
T.writes(add_compute[v0, v1, v2, v3, v4])
add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable

@@ -161,6 +197,14 @@ def test_vectorize_inner_loop():
tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize)


def test_parallel_vectorize_add():
sch = Schedule(before_postproc_add)
rule = RewriteParallelVectorizeUnroll()
assert rule.apply(sch)
tvm.ir.assert_structural_equal(sch.mod["main"], after_postproc_add)


if __name__ == "__main__":
test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()
test_vectorize_inner_loop()
test_parallel_vectorize_add()
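

A quick sanity check on the extents in after_postproc_add (my own arithmetic, not part of the PR): the fused extent of the five spatial loops is 1 * 8 * 56 * 56 * 32 = 802816, and splitting by the vectorize length 128 leaves an outer parallel extent of 802816 / 128 = 6272, which is exactly the T.parallel(0, 6272) / T.vectorized(0, 128) nest the test expects.

# Assumed derivation of the loop extents in after_postproc_add (not part of the PR).
fused_extent = 1 * 8 * 56 * 56 * 32    # extents of n, c0, h, w, c1
vec_len = 128                          # from the "meta_schedule.vectorize" annotation
assert fused_extent == 802816
assert fused_extent == 6272 * vec_len  # outer parallel extent times inner vectorized extent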