From cabda5473b04b1d03c24c3b5a39b2da0e3892843 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 14 Dec 2021 20:15:36 -0500 Subject: [PATCH 01/20] WIP --- include/tvm/arith/iter_affine_map.h | 7 + include/tvm/tir/schedule/schedule.h | 6 + src/tir/schedule/concrete_schedule.cc | 9 + src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/primitive.h | 3 + .../schedule/primitive/blockize_tensorize.cc | 346 ++++++++++++++++++ src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 11 + src/tir/schedule/traced_schedule.h | 1 + .../unittest/test_tir_schedule_blockize.py | 227 ++++++++++++ 10 files changed, 613 insertions(+) create mode 100644 src/tir/schedule/primitive/blockize_tensorize.cc create mode 100644 tests/python/unittest/test_tir_schedule_blockize.py diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 22b4cd580e18..eb69c188abf3 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -350,6 +350,13 @@ Array> SubspaceDivide(const Array& bindings, bool require_bijective, arith::Analyzer* analyzer, DiagnosticContext diag_ctx); +/*! + * \brief Given an IterMapExpr, transform it to normal PrimExpr. + * \param expr The input IterMapExpr. + * \return The corresponding normal PrimExpr. + */ +PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_ITER_AFFINE_MAP_H_ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 43f2379a0b56..411b20d6a855 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -473,6 +473,12 @@ class ScheduleNode : public runtime::Object { */ virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; /******** Schedule: Blockize & Tensorize ********/ + /*! + * \brief Convert the subtree rooted by a specific loop into a block. + * \param loop_rv The root of the subtree + * \return The new block + */ + virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 9f8dc6dd2daf..7e694a938550 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -606,6 +606,15 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { } /******** Schedule: Blockize & Tensorize ********/ +BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Blockize(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); + return CreateRV(result); +} + /******** Schedule: Annotation ********/ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 96cb0f728835..e95a432985b6 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -123,6 +123,7 @@ class ConcreteScheduleNode : public ScheduleNode { int offset) override; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) override; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index f0b38af01b5f..09df62e5025b 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -378,6 +378,9 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer const String& storage_scope); /******** Schedule: Blockize & Tensorize ********/ + +TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); + /******** Schedule: Annotation ********/ /*! * \brief Annotate a block/loop with a key value pair diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc new file mode 100644 index 000000000000..4b1f3fe411c9 --- /dev/null +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../../arith/pattern_match.h" +#include "../../ir/functor_common.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +Array> TrivialSubspaceDivision(const Array& iter_vars, + const Array& bindings, + const std::vector& outer_loops, + const std::vector& inner_loops, + const PrimExpr& predicate) { + if (!is_one(predicate)) return {}; + std::vector> res; + std::unordered_set outer_loop_vars; + std::unordered_set inner_loop_vars; + for (const Var& var : outer_loops) { + outer_loop_vars.insert(var.get()); + } + for (const Var& var : inner_loops) { + inner_loop_vars.insert(var.get()); + } + for (size_t i = 0; i < bindings.size(); ++i) { + bool outer = UsesVar( + bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); }); + bool inner = UsesVar( + bindings[i], [&inner_loop_vars](const VarNode* var) { return inner_loop_vars.count(var); }); + bool is_var = bindings[i]->IsInstance(); + if (outer && !inner) { + arith::IterMark outer{nullptr}; + if (is_var) { + outer = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + outer = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + arith::IterMark inner(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else if (inner && !outer) { + arith::IterMark inner{nullptr}; + if (is_var) { + inner = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + inner = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + arith::IterMark outer(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else if (!outer && !inner) { + arith::IterMark outer(arith::IterSumExpr({}, 0), 1); + arith::IterMark inner(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else { + return {}; + } + } + res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), + arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); + return res; +} + +StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { + /*! + * Check: + * - The sub AST is one-line with only one block + * + * Mutate: + * - extra block var from the only block + * - Update block binding + */ + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); + StmtSRef block_sref = self->stmt2ref.at(block_realize.get()); + arith::Analyzer analyzer; + + // Step 1: Collect loops inside/outside loop_sref + std::vector outer_loops, inner_loops; + std::vector outer_iters, inner_iters; + std::unordered_map iters; + bool inner = true; + Block block = block_realize->block; + for (StmtSRef current_sref = block_sref;;) { + current_sref = GetRef(current_sref->parent); + if (!current_sref.defined()) break; + const auto* current_loop = current_sref->StmtAs(); + if (!current_loop) break; + if (inner) { + inner_loops.push_back(current_loop); + inner_iters.push_back(current_loop->loop_var); + } else { + outer_loops.push_back(current_loop); + outer_iters.push_back(current_loop->loop_var); + } + iters[current_loop->loop_var] = Range::FromMinExtent(current_loop->min, current_loop->extent); + if (current_sref == loop_sref) inner = false; + } + + // Step 2: Calculate subspace division for the inner loops + Array> division = arith::SubspaceDivide( + block_realize->iter_values, iters, inner_iters, block_realize->predicate, false, &analyzer); + if (division.empty()) { + // It is possible to blockize if we can not do perfect subspace division if we can divide + // the block var bindings into two categories + // 1. The binding covers no inner loop var + // 2. The binding covers only inner loop vars + division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, outer_iters, + inner_iters, block_realize->predicate); + } + CHECK(!division.empty()) << "ValueError: The bindings of the block below can not be blockized"; + + // Step 3: Generate a new inner block + Array inner_block_vars, outer_block_vars; + Array inner_bindings, outer_bindings; + std::unordered_map bv_iters; // iter_vars of the inner block + for (size_t i = 0; i < block->iter_vars.size(); ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const arith::IterMapExprNode* outer_binding = + division[i][0]->source.as(); + const arith::IterMapExprNode* inner_binding = + division[i][1]->source.as(); + ICHECK(outer_binding); + ICHECK(inner_binding); + if (is_one(division[i][1]->extent)) { // IsOuter + // extract this iter var to outer block directly + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_block_vars.push_back(iter_var); + // bv_iters[iter_var->var] = Range::FromMinExtent(0, division[i][0]->extent); + } else { + const IterVar outer_var(Range::FromMinExtent(0, division[i][0]->extent), + iter_var->var.copy_with_suffix("o"), iter_var->iter_type); + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_block_vars.push_back(outer_var); + // generate a new iter var for outer block + PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; + if (const auto* op = division[i][1]->source.as()) { + base = base + op->base; + inner_bindings.push_back(base + + arith::NormalizeIterMapToExpr(arith::IterSumExpr(op->args, 0))); + } else { + inner_bindings.push_back( + base + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); + } + inner_block_vars.push_back(iter_var); + bv_iters[iter_var->var] = Range::FromMinExtent(base, division[i][1]->extent); + } + } + Block inner_block = block; + inner_block.CopyOnWrite()->iter_vars = inner_block_vars; + inner_block.CopyOnWrite()->init = NullOpt; + BlockRealize inner_br = block_realize; + inner_br.CopyOnWrite()->iter_values = inner_bindings; + inner_br.CopyOnWrite()->predicate = division.back()[1]->extent; + inner_br.CopyOnWrite()->block = inner_block; + // Regenerate inner_loops + Stmt body = inner_br; + for (const auto& inner_loop : inner_loops) { + auto loop_node = make_object(*inner_loop); + loop_node->body = body; + body = For(loop_node); + } + // Regenerate init for outer block + Optional new_init = NullOpt; + if (block->init.defined()) { + std::vector init_loops; + std::vector init_block_vars; + std::vector init_block_vars_copy; + std::vector init_bindings; + std::unordered_map binding_replace_map; + std::unordered_map bv_replace_map; + std::unordered_map new_block_vars2old_index; + for (size_t i = 0; i < inner_block_vars.size(); ++i) { + if (inner_block_vars[i]->iter_type == IterVarType::kDataPar && + UsesVar(block->init.value(), + [v = inner_block_vars[i]->var](const VarNode* var) { return var == v.get(); })) { + // copy init block vars and ignore reduce block vars + init_block_vars.push_back(i); + IterVar init_block_var = inner_block_vars[i]; + init_block_var.CopyOnWrite()->var = inner_block_vars[i]->var.copy_with_suffix("_init"); + init_block_vars_copy.push_back(init_block_var); + bv_replace_map[inner_block_vars[i]->var] = init_block_var->var; + new_block_vars2old_index[init_block_var.get()] = i; + } + } + for (const ForNode* inner_loop : inner_loops) { + for (size_t i = 0; i < init_block_vars.size(); ++i) { + if (UsesVar(inner_bindings[new_block_vars2old_index[init_block_vars_copy[i].get()]], + [v = inner_loop->loop_var](const VarNode* var) { return var == v.get(); })) { + // copy loops related to init block vars + For init_loop = GetRef(inner_loop); + init_loop.CopyOnWrite()->loop_var = inner_loop->loop_var.copy_with_suffix(""); + // replace loop vars with copied loop vars + binding_replace_map[inner_loop->loop_var] = init_loop->loop_var; + init_loops.push_back(init_loop); + break; + } + } + } + for (size_t i = 0; i < init_block_vars.size(); ++i) { + init_bindings.push_back(Substitute(inner_bindings[init_block_vars[i]], binding_replace_map)); + } + new_init = Substitute(Block(/*iter_vars=*/init_block_vars_copy, // + /*reads=*/{}, // + /*writes=*/block->writes, // + /*name_hint=*/block->name_hint + "_init", // + /*body=*/block->init.value(), // + /*init=*/NullOpt), + bv_replace_map); + new_init = + BlockRealize(init_bindings, division.back()[1]->extent, Downcast(new_init.value())); + for (const auto& init_loop : init_loops) { + For new_init_loop = init_loop; + new_init_loop.CopyOnWrite()->body = new_init.value(); + new_init = new_init_loop; + } + } + // Calculate outer block's IO region + auto rewrite_range = [&](const Range& range) -> Range { + const Array& res = + arith::DetectIterMap({range->min}, bv_iters, true, false, &analyzer); + ICHECK_EQ(res.size(), 1); + const arith::IterSumExpr& normalized_expr = res[0]; + PrimExpr extent = 1; + if (normalized_expr->args.size() == 1) { + CHECK(analyzer.CanProve(normalized_expr->args[0]->scale - range->extent == 0)); + extent = normalized_expr->args[0]->extent; + } + return Range::FromMinExtent(normalized_expr->base, extent * range->extent); + }; + std::vector reads, writes; + auto rewrite_region = [&](std::vector* regions, Array old_regions) { + for (auto buffer_region : old_regions) { + std::vector region; + for (const auto& range : buffer_region->region) { + region.push_back(rewrite_range(range)); + } + (*regions).emplace_back(buffer_region->buffer, region); + } + }; + rewrite_region(&reads, block->reads); + rewrite_region(&writes, block->writes); + // Generate a new outer block + auto outer_block = Block(/*iter_vars=*/outer_block_vars, // + /*reads=*/reads, // + /*writes=*/writes, // + /*name_hint=*/"blockized_" + block->name_hint, // + /*body=*/std::move(body), // + /*init=*/new_init); + auto outer_realize = BlockRealize(outer_bindings, division.back()[0]->extent, outer_block); + + self->Replace(loop_sref, outer_realize, {{block, inner_block}}); + { + StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_compact_dataflow*/ false); + UpdateScope(self, scope_sref); + } + RecalculateCachedFlags(self.operator->()); + + // } + // TODO(@wuwei): fix affine flags + // self->Replace(loop_sref, outer_realize, {{block, inner_block}}); + // { + // StmtSRef block_sref = self->stmt2ref.at(inner_block.get()); + // UpdateAffineFlag(self, block_sref); + // } + // { + // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + // StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + // /*require_compact_dataflow*/false); + // UpdateScope(self, scope_sref); + // UpdateAffineFlag(self, scope_sref); + // } + // { + // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + // UpdateScope(self, block_sref); + // UpdateAffineFlag(self, block_sref); + // } + + // // Check loop binding + + // { + // struct BindingValidator : public StmtVisitor { + // void VisitStmt_(const BlockRealizeNode* realize) final { + // StmtSRef& sref = self->stmt2ref.at(realize->block.get()); + // UpdateAffineFlag(self, sref); + // VisitStmt(realize->block->body); + // } + // ScheduleState self; + // }; + // BindingValidator validator; + // validator.self = self; + // const PrimFuncNode* func = GetRootPrimFunc(self->mod, GetRootBlock(loop_sref).get(), + // nullptr); validator(func->body); + // } + return self->stmt2ref.at(outer_block.get()); +} + +struct BlockizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Blockize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Blockize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("blockize"); + py.Input("loop", loop_rv); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 6e33862c07ca..f9a7c4bab614 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -185,6 +185,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") .set_body_method(&ScheduleNode::SetScope); /******** (FFI) Blockize & Tensorize ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") + .set_body_method(&ScheduleNode::Blockize); /******** (FFI) Annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index da7a2641b162..fc356cf50484 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -356,6 +356,17 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /******** Schedule: Blockize & Tensorize ********/ +BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { + BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv); + static const InstructionKind& kind = InstructionKind::Get("Blockize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{new_block})); + return new_block; +} + /******** Schedule: Annotation ********/ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index b35f1b6e17bb..9940749d2f3e 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -87,6 +87,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { int offset) final; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py new file mode 100644 index 000000000000..57071cd7ad5b --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -0,0 +1,227 @@ +import sys +import pytest +import tvm +from tvm import tir, te +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def blockize(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, (128, 128), "float32") + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(8, 8): + with T.block("blockized_B"): + vi, vj = T.axis.remap("SS", [i, j]) + for ii, jj in T.grid(16, 16): + with T.block("B"): + vii = T.axis.S(128, vi * 16 + ii) + vjj = T.axis.S(128, vj * 16 + jj) + B[vii, vjj] = A[vii, vjj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def blockize_schedule_1(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0_outer in range(0, 8): + for i1_outer in range(0, 8): + with T.block("blockized_B"): + vio = T.axis.S(8, i0_outer) + vjo = T.axis.S(8, i1_outer) + T.reads([A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + T.writes([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + for i0_inner in range(0, 16): + for i1_inner in range(0, 16): + with T.block("B"): + vi = T.axis.S(128, ((vio * 16) + i0_inner)) + vj = T.axis.S(128, ((vjo * 16) + i1_inner)) + T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + with T.block("blockized_C"): + vio = T.axis.S(8, i0_outer) + vjo = T.axis.S(8, i1_outer) + T.reads([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + T.writes([C[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + for ax0 in range(0, 16): + for ax1 in range(0, 16): + with T.block("C"): + vi = T.axis.S(128, ((vio * 16) + ax0)) + vj = T.axis.S(128, ((vjo * 16) + ax1)) + T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def blockize_schedule_2(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0_outer in range(0, 4): + for i1_outer in range(0, 4): + for ax0 in range(0, 2): + for ax1 in range(0, 2): + with T.block("blockized_B"): + vio = T.axis.S(8, ((i0_outer * 2) + ax0)) + vjo = T.axis.S(8, ((i1_outer * 2) + ax1)) + T.reads( + [A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] + ) + T.writes( + [B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] + ) + for i0_inner in range(0, 16): + for i1_inner in range(0, 16): + with T.block("B"): + vi = T.axis.S(128, ((vio * 16) + i0_inner)) + vj = T.axis.S(128, ((vjo * 16) + i1_inner)) + T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i0_inner_1 in range(0, 32): + for i1_inner_1 in range(0, 32): + with T.block("C"): + vi = T.axis.S(128, ((i0_outer * 32) + i0_inner_1)) + vj = T.axis.S(128, ((i1_outer * 32) + i1_inner_1)) + T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def rowsum(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer( + b, + [ + 128, + ], + ) + for k, i in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [k, i]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def rowsum_blockized(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128]) + with T.block("blockized_B"): + vko = T.axis.R(1, 0) + vio = T.axis.S(1, 0) + with T.init(): + for i1 in T.serial(0, 128): + with T.block("B_init"): + vi_init = T.axis.S(128, i1) + B[vi_init] = T.float32(0) + for i0, i1_1 in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [i0, i1_1]) + B[vi] = B[vi] + A[vi, vk] + + +def test_blockize(): + func = elementwise + # schedule + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + _ = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + tvm.ir.assert_structural_equal(blockize, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_schedule(): + func = elementwise + # test 1 + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.reverse_compute_at(C, yo) + s.blockize(s.get_loops(C)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + verify_trace_roundtrip(sch=s, mod=func) + # test 2 + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(C) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.compute_at(B, yo) + s.blockize(s.get_loops(B)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + verify_trace_roundtrip(sch=s, mod=func) + # test 3 + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + b_outer = s.blockize(xi) + xC, yC = s.get_loops(C) + xCo, xCi = s.split(xC, factors=[None, 32]) + yCo, yCi = s.split(yC, factors=[None, 32]) + s.reorder(xCo, yCo, xCi, yCi) + s.compute_at(b_outer, yCo) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_2) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_init_loops(): + s = tir.Schedule(rowsum, debug_mask="all") + k, _ = s.get_loops(s.get_block("B")) + s.blockize(k) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) + verify_trace_roundtrip(sch=s, mod=rowsum) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 37896629916c6d957a5825f9303ed63b69550ae0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 15 Dec 2021 18:04:30 -0500 Subject: [PATCH 02/20] WIP --- python/tvm/tir/schedule/schedule.py | 4 + .../schedule/primitive/blockize_tensorize.cc | 317 +++++++++++------- .../unittest/test_tir_schedule_blockize.py | 39 +-- 3 files changed, 226 insertions(+), 134 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7d352f156a31..14950ac5fefb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1759,6 +1759,10 @@ def after_set_scope( ########## Schedule: Blockize & Tensorize ########## + @type_checked + def blockize(self, loop: LoopRV) -> BlockRV: + return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member + ########## Schedule: Annotation ########## @type_checked diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 4b1f3fe411c9..e5261ebffc5b 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -23,10 +23,21 @@ namespace tvm { namespace tir { +/*! + * \brief Detect if bindings are a trivial case of the subspace division where we can divide the + * block iter bindings into two categories: + * 1. The binding covers no inner loop vars. + * 2. The binding covers only inner loop vars. + * This case doesn't require bindings to be quasi-affine. + * + * \param + * \param + * \return The result of the subspace division. + */ Array> TrivialSubspaceDivision(const Array& iter_vars, const Array& bindings, - const std::vector& outer_loops, - const std::vector& inner_loops, + const Array& outer_loops, + const Array& inner_loops, const PrimExpr& predicate) { if (!is_one(predicate)) return {}; std::vector> res; @@ -79,59 +90,72 @@ Array> TrivialSubspaceDivision(const Array& iter return res; } -StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { - /*! - * Check: - * - The sub AST is one-line with only one block - * - * Mutate: - * - extra block var from the only block - * - Update block binding - */ - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); - StmtSRef block_sref = self->stmt2ref.at(block_realize.get()); - arith::Analyzer analyzer; - - // Step 1: Collect loops inside/outside loop_sref - std::vector outer_loops, inner_loops; - std::vector outer_iters, inner_iters; - std::unordered_map iters; - bool inner = true; - Block block = block_realize->block; - for (StmtSRef current_sref = block_sref;;) { - current_sref = GetRef(current_sref->parent); - if (!current_sref.defined()) break; - const auto* current_loop = current_sref->StmtAs(); - if (!current_loop) break; - if (inner) { - inner_loops.push_back(current_loop); - inner_iters.push_back(current_loop->loop_var); - } else { - outer_loops.push_back(current_loop); - outer_iters.push_back(current_loop->loop_var); +Stmt GenerateOuterInitBlock(const Array& inner_block_vars, Block block, + const std::vector& inner_loops, + const Array& inner_bindings, + const Array>& division) { + std::vector init_loops; + Stmt new_init; + std::vector init_block_vars; + std::vector init_block_vars_copy; + std::vector init_bindings; + std::unordered_map binding_replace_map; + std::unordered_map bv_replace_map; + std::unordered_map new_block_vars2old_index; + for (size_t i = 0; i < inner_block_vars.size(); ++i) { + if (inner_block_vars[i]->iter_type == IterVarType::kDataPar && + UsesVar(block->init.value(), + [v = inner_block_vars[i]->var](const VarNode* var) { return var == v.get(); })) { + // copy init block vars and ignore reduce block vars + init_block_vars.push_back(i); + IterVar init_block_var = inner_block_vars[i]; + init_block_var.CopyOnWrite()->var = inner_block_vars[i]->var.copy_with_suffix("_init"); + init_block_vars_copy.push_back(init_block_var); + bv_replace_map[inner_block_vars[i]->var] = init_block_var->var; + new_block_vars2old_index[init_block_var.get()] = i; } - iters[current_loop->loop_var] = Range::FromMinExtent(current_loop->min, current_loop->extent); - if (current_sref == loop_sref) inner = false; } - - // Step 2: Calculate subspace division for the inner loops - Array> division = arith::SubspaceDivide( - block_realize->iter_values, iters, inner_iters, block_realize->predicate, false, &analyzer); - if (division.empty()) { - // It is possible to blockize if we can not do perfect subspace division if we can divide - // the block var bindings into two categories - // 1. The binding covers no inner loop var - // 2. The binding covers only inner loop vars - division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, outer_iters, - inner_iters, block_realize->predicate); + for (const ForNode* inner_loop : inner_loops) { + for (size_t i = 0; i < init_block_vars.size(); ++i) { + if (UsesVar(inner_bindings[new_block_vars2old_index[init_block_vars_copy[i].get()]], + [v = inner_loop->loop_var](const VarNode* var) { return var == v.get(); })) { + // copy loops related to init block vars + For init_loop = GetRef(inner_loop); + init_loop.CopyOnWrite()->loop_var = inner_loop->loop_var.copy_with_suffix(""); + // replace loop vars with copied loop vars + binding_replace_map[inner_loop->loop_var] = init_loop->loop_var; + init_loops.push_back(init_loop); + break; + } + } } - CHECK(!division.empty()) << "ValueError: The bindings of the block below can not be blockized"; + for (size_t i = 0; i < init_block_vars.size(); ++i) { + init_bindings.push_back(Substitute(inner_bindings[init_block_vars[i]], binding_replace_map)); + } + new_init = Substitute(Block(/*iter_vars=*/init_block_vars_copy, // + /*reads=*/{}, // + /*writes=*/block->writes, // + /*name_hint=*/block->name_hint + "_init", // + /*body=*/block->init.value(), // + /*init=*/NullOpt), + bv_replace_map); + new_init = BlockRealize(init_bindings, division.back()[1]->extent, Downcast(new_init)); + for (const auto& init_loop : init_loops) { + For new_init_loop = init_loop; + new_init_loop.CopyOnWrite()->body = new_init; + new_init = new_init_loop; + } + return new_init; +} + +// TODO +class SubspaceNotDivisibleError : public ScheduleError {}; - // Step 3: Generate a new inner block +std::array, Array>, 2> GenerateBlockIterVarBindings( + Block block, const Array>& division, + std::unordered_map& bv_iters) { Array inner_block_vars, outer_block_vars; Array inner_bindings, outer_bindings; - std::unordered_map bv_iters; // iter_vars of the inner block for (size_t i = 0; i < block->iter_vars.size(); ++i) { const IterVar& iter_var = block->iter_vars[i]; const arith::IterMapExprNode* outer_binding = @@ -166,16 +190,132 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { bv_iters[iter_var->var] = Range::FromMinExtent(base, division[i][1]->extent); } } - Block inner_block = block; - inner_block.CopyOnWrite()->iter_vars = inner_block_vars; - inner_block.CopyOnWrite()->init = NullOpt; - BlockRealize inner_br = block_realize; - inner_br.CopyOnWrite()->iter_values = inner_bindings; - inner_br.CopyOnWrite()->predicate = division.back()[1]->extent; - inner_br.CopyOnWrite()->block = inner_block; + return {std::make_pair(outer_block_vars, outer_bindings), + std::make_pair(inner_block_vars, inner_bindings)}; +} + +/*! + * \brief A helper to collect the parent loops of the block. The loops are divided into two groups, + * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the + * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its + * successor loops. It is possible that 'outer_loops' is empty. + */ +class LoopSubspaceCollector { + public: + /*! + * \brief Collect the parent loops of the block and store the result in the corresponding fields. + * \param block_sref The sref to the target block. + * \param loop_sref The sref to the separator loop. + */ + void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) { + bool inner = true; + for (StmtSRefNode* current_sref = block_sref->parent; + current_sref && current_sref->stmt->IsInstance(); + current_sref = current_sref->parent) { + const auto* current_loop = current_sref->StmtAs(); + ICHECK(current_loop); + if (inner) { + inner_loops.push_back(current_loop); + inner_loop_vars.push_back(current_loop->loop_var); + } else { + outer_loops.push_back(current_loop); + outer_loop_vars.push_back(current_loop->loop_var); + } + loop_var_domain.Set(current_loop->loop_var, + Range::FromMinExtent(current_loop->min, current_loop->extent)); + if (current_sref == loop_sref.get()) inner = false; + } + } + /*! \brief Outer loops which are ancestors of the separator. */ + std::vector outer_loops; + /*! \brief Inner loops which are the separator itself or its successors. */ + std::vector inner_loops; + /*! \brief Loop variables of the outer loops. */ + Array outer_loop_vars; + /*! \brief Loop variables of the inner loops. */ + Array inner_loop_vars; + /*! \brief Domain of the loop variables. */ + Map loop_var_domain; +}; + +/*! + * \brief Check the bindings of the block iters can be divided by a subspace collected by the + * collector. A ScheduleError is raised if the bindings are not divisible by the subspace. + * \param block_realize The block realize to be checked. + * \param collector The collector which has collected the loops of the block. + * \param analyzer The arithmetic analyzer. + * \return The result of the subspace division. + */ +Array> CheckSubspaceDivisible(const BlockRealize& block_realize, + const LoopSubspaceCollector& collector, + arith::Analyzer* analyzer) { + const Block& block = block_realize->block; + + Array> division = + arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, + collector.inner_loop_vars, block_realize->predicate, + /*require_bijective=*/false, analyzer); + + if (division.empty()) { + // If we can't do perfect subspace division, check if it is a trivial case of subspace division. + // In this case, we can still blockize. + division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, + collector.outer_loop_vars, collector.inner_loop_vars, + block_realize->predicate); + } + // TODO: raise schedule error + CHECK(!division.empty()) << "ValueError: The bindings of the block below can not be blockized"; + return division; +} + +class BlockizeBlockBuilder { + +}; +StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { + /*! + * Check: + * - The sub AST is one-line with only one block + * + * Mutate: + * - extra block var from the only block + * - Update block binding + */ + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + arith::Analyzer analyzer; + + // Step 1: Check the loop has a single child BlockRealize on the sref tree. + BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); + Block block = block_realize->block; + StmtSRef block_sref = self->stmt2ref.at(block.get()); + + // Step 2: Collect loops inside and outside loop_sref. + LoopSubspaceCollector collector; + collector.Collect(block_sref, loop_sref); + + // Step 3: Calculate subspace division for the inner loops. + Array> division = + CheckSubspaceDivisible(block_realize, collector, &analyzer); + + // Step 4: Generate a new inner block + // Step 4.1: Compute outer/inner block bindings + std::unordered_map + bv_iters; // iter_vars of the inner block + auto r = GenerateBlockIterVarBindings(block, division, bv_iters); + Array outer_block_vars = r[0].first; + Array outer_bindings = r[0].second; + + auto* inner_br = block_realize.CopyOnWrite(); + auto* inner_block = inner_br->block.CopyOnWrite(); + { + inner_br->iter_values = r[1].second; + inner_br->predicate = division.back()[1]->extent; + inner_block->iter_vars = r[1].first; + inner_block->init = NullOpt; + } + // Regenerate inner_loops - Stmt body = inner_br; - for (const auto& inner_loop : inner_loops) { + Stmt body = GetRef(inner_br); + for (const auto& inner_loop : collector.inner_loops) { auto loop_node = make_object(*inner_loop); loop_node->body = body; body = For(loop_node); @@ -183,58 +323,10 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { // Regenerate init for outer block Optional new_init = NullOpt; if (block->init.defined()) { - std::vector init_loops; - std::vector init_block_vars; - std::vector init_block_vars_copy; - std::vector init_bindings; - std::unordered_map binding_replace_map; - std::unordered_map bv_replace_map; - std::unordered_map new_block_vars2old_index; - for (size_t i = 0; i < inner_block_vars.size(); ++i) { - if (inner_block_vars[i]->iter_type == IterVarType::kDataPar && - UsesVar(block->init.value(), - [v = inner_block_vars[i]->var](const VarNode* var) { return var == v.get(); })) { - // copy init block vars and ignore reduce block vars - init_block_vars.push_back(i); - IterVar init_block_var = inner_block_vars[i]; - init_block_var.CopyOnWrite()->var = inner_block_vars[i]->var.copy_with_suffix("_init"); - init_block_vars_copy.push_back(init_block_var); - bv_replace_map[inner_block_vars[i]->var] = init_block_var->var; - new_block_vars2old_index[init_block_var.get()] = i; - } - } - for (const ForNode* inner_loop : inner_loops) { - for (size_t i = 0; i < init_block_vars.size(); ++i) { - if (UsesVar(inner_bindings[new_block_vars2old_index[init_block_vars_copy[i].get()]], - [v = inner_loop->loop_var](const VarNode* var) { return var == v.get(); })) { - // copy loops related to init block vars - For init_loop = GetRef(inner_loop); - init_loop.CopyOnWrite()->loop_var = inner_loop->loop_var.copy_with_suffix(""); - // replace loop vars with copied loop vars - binding_replace_map[inner_loop->loop_var] = init_loop->loop_var; - init_loops.push_back(init_loop); - break; - } - } - } - for (size_t i = 0; i < init_block_vars.size(); ++i) { - init_bindings.push_back(Substitute(inner_bindings[init_block_vars[i]], binding_replace_map)); - } - new_init = Substitute(Block(/*iter_vars=*/init_block_vars_copy, // - /*reads=*/{}, // - /*writes=*/block->writes, // - /*name_hint=*/block->name_hint + "_init", // - /*body=*/block->init.value(), // - /*init=*/NullOpt), - bv_replace_map); - new_init = - BlockRealize(init_bindings, division.back()[1]->extent, Downcast(new_init.value())); - for (const auto& init_loop : init_loops) { - For new_init_loop = init_loop; - new_init_loop.CopyOnWrite()->body = new_init.value(); - new_init = new_init_loop; - } + new_init = GenerateOuterInitBlock(inner_block->iter_vars, block, collector.inner_loops, + inner_br->iter_values, division); } + // Calculate outer block's IO region auto rewrite_range = [&](const Range& range) -> Range { const Array& res = @@ -269,14 +361,15 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { /*init=*/new_init); auto outer_realize = BlockRealize(outer_bindings, division.back()[0]->extent, outer_block); - self->Replace(loop_sref, outer_realize, {{block, inner_block}}); + // Step x: do actual replace + self->Replace(loop_sref, outer_realize, {{block, GetRef(inner_block)}}); { StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, /*require_compact_dataflow*/ false); - UpdateScope(self, scope_sref); + // UpdateScope(self, scope_sref); } - RecalculateCachedFlags(self.operator->()); + // RecalculateCachedFlags(self.operator->()); // } // TODO(@wuwei): fix affine flags diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index 57071cd7ad5b..8629ba2f4f9d 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -42,13 +42,13 @@ def blockize(a: T.handle, c: T.handle) -> None: @T.prim_func def blockize_schedule_1(a: T.handle, c: T.handle) -> None: - C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) - A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + C = T.match_buffer(c, [128, 128]) + A = T.match_buffer(a, [128, 128]) # body with T.block("root"): T.reads([]) T.writes([]) - B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + B = T.alloc_buffer([128, 128], "float32") for i0_outer in range(0, 8): for i1_outer in range(0, 8): with T.block("blockized_B"): @@ -81,13 +81,13 @@ def blockize_schedule_1(a: T.handle, c: T.handle) -> None: @T.prim_func def blockize_schedule_2(a: T.handle, c: T.handle) -> None: - C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) - A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + C = T.match_buffer(c, [128, 128]) + A = T.match_buffer(a, [128, 128]) # body with T.block("root"): T.reads([]) T.writes([]) - B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + B = T.alloc_buffer([128, 128]) for i0_outer in range(0, 4): for i1_outer in range(0, 4): for ax0 in range(0, 2): @@ -122,12 +122,7 @@ def blockize_schedule_2(a: T.handle, c: T.handle) -> None: @T.prim_func def rowsum(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128]) - B = T.match_buffer( - b, - [ - 128, - ], - ) + B = T.match_buffer(b, [128]) for k, i in T.grid(128, 128): with T.block("B"): vk, vi = T.axis.remap("RS", [k, i]) @@ -157,7 +152,7 @@ def rowsum_blockized(a: T.handle, b: T.handle) -> None: def test_blockize(): func = elementwise # schedule - s = tir.Schedule(func, debug_mask="all") + s = tir.Schedule(func, debug_mask="none") B = s.get_block("B") _ = s.get_block("C") x, y = s.get_loops(B) @@ -166,13 +161,13 @@ def test_blockize(): s.reorder(xo, yo, xi, yi) s.blockize(xi) tvm.ir.assert_structural_equal(blockize, s.mod["main"]) - verify_trace_roundtrip(sch=s, mod=func) + # verify_trace_roundtrip(sch=s, mod=func) def test_blockize_schedule(): func = elementwise # test 1 - s = tir.Schedule(func, debug_mask="all") + s = tir.Schedule(func, debug_mask="none") B = s.get_block("B") C = s.get_block("C") x, y = s.get_loops(B) @@ -183,9 +178,9 @@ def test_blockize_schedule(): s.reverse_compute_at(C, yo) s.blockize(s.get_loops(C)[-2]) tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) - verify_trace_roundtrip(sch=s, mod=func) + # verify_trace_roundtrip(sch=s, mod=func) # test 2 - s = tir.Schedule(func, debug_mask="all") + s = tir.Schedule(func, debug_mask="none") B = s.get_block("B") C = s.get_block("C") x, y = s.get_loops(C) @@ -196,9 +191,9 @@ def test_blockize_schedule(): s.compute_at(B, yo) s.blockize(s.get_loops(B)[-2]) tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) - verify_trace_roundtrip(sch=s, mod=func) + # verify_trace_roundtrip(sch=s, mod=func) # test 3 - s = tir.Schedule(func, debug_mask="all") + s = tir.Schedule(func, debug_mask="none") B = s.get_block("B") C = s.get_block("C") x, y = s.get_loops(B) @@ -212,15 +207,15 @@ def test_blockize_schedule(): s.reorder(xCo, yCo, xCi, yCi) s.compute_at(b_outer, yCo) tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_2) - verify_trace_roundtrip(sch=s, mod=func) + # verify_trace_roundtrip(sch=s, mod=func) def test_blockize_init_loops(): - s = tir.Schedule(rowsum, debug_mask="all") + s = tir.Schedule(rowsum, debug_mask="none") k, _ = s.get_loops(s.get_block("B")) s.blockize(k) tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) - verify_trace_roundtrip(sch=s, mod=rowsum) + # verify_trace_roundtrip(sch=s, mod=rowsum) if __name__ == "__main__": From 4b9e6e19469ecb01ece37a8ecdaa0a8df439ed5f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 17 Dec 2021 16:06:52 -0500 Subject: [PATCH 03/20] WIP --- .../schedule/primitive/blockize_tensorize.cc | 488 +++++++++--------- .../unittest/test_tir_schedule_blockize.py | 44 +- 2 files changed, 286 insertions(+), 246 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index e5261ebffc5b..37a8fe522546 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -16,37 +16,41 @@ * specific language governing permissions and limitations * under the License. */ -#include "../../../arith/pattern_match.h" -#include "../../ir/functor_common.h" +#include + #include "../utils.h" namespace tvm { namespace tir { /*! - * \brief Detect if bindings are a trivial case of the subspace division where we can divide the + * \brief Detect if bindings are a trivial case of the subspace division where we can divide the * block iter bindings into two categories: * 1. The binding covers no inner loop vars. * 2. The binding covers only inner loop vars. - * This case doesn't require bindings to be quasi-affine. * - * \param - * \param + * The bindings are not required to be quasi-affine. + * + * \param iter_vars The input iterators + * \param bindings The values of iter_vars + * \param outer_loops Iterators outside the subspace. + * \param inner_loops Iterators of the subspace + * \param predicate The predicate constaints on the input iterators. * \return The result of the subspace division. */ Array> TrivialSubspaceDivision(const Array& iter_vars, const Array& bindings, - const Array& outer_loops, - const Array& inner_loops, + const Array& outer_iters, + const Array& inner_iters, const PrimExpr& predicate) { if (!is_one(predicate)) return {}; std::vector> res; std::unordered_set outer_loop_vars; std::unordered_set inner_loop_vars; - for (const Var& var : outer_loops) { + for (const Var& var : outer_iters) { outer_loop_vars.insert(var.get()); } - for (const Var& var : inner_loops) { + for (const Var& var : inner_iters) { inner_loop_vars.insert(var.get()); } for (size_t i = 0; i < bindings.size(); ++i) { @@ -54,33 +58,27 @@ Array> TrivialSubspaceDivision(const Array& iter bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); }); bool inner = UsesVar( bindings[i], [&inner_loop_vars](const VarNode* var) { return inner_loop_vars.count(var); }); - bool is_var = bindings[i]->IsInstance(); + arith::IterMark iter_mark; + if (bindings[i]->IsInstance()) { + iter_mark = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } if (outer && !inner) { arith::IterMark outer{nullptr}; - if (is_var) { - outer = arith::IterMark( - arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), - iter_vars[i]->dom->extent); - } else { - outer = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); - } - arith::IterMark inner(arith::IterSumExpr({}, 0), 1); - res.push_back(Array({outer, inner})); + const auto& outer_iter = iter_mark; + arith::IterMark inner_iter(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer_iter, inner_iter})); } else if (inner && !outer) { - arith::IterMark inner{nullptr}; - if (is_var) { - inner = arith::IterMark( - arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), - iter_vars[i]->dom->extent); - } else { - inner = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); - } - arith::IterMark outer(arith::IterSumExpr({}, 0), 1); - res.push_back(Array({outer, inner})); + const auto& inner_iter = iter_mark; + arith::IterMark outer_iter(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer_iter, inner_iter})); } else if (!outer && !inner) { - arith::IterMark outer(arith::IterSumExpr({}, 0), 1); - arith::IterMark inner(arith::IterSumExpr({}, 0), 1); - res.push_back(Array({outer, inner})); + arith::IterMark outer_iter(arith::IterSumExpr({}, 0), 1); + arith::IterMark inner_iter(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer_iter, inner_iter})); } else { return {}; } @@ -90,108 +88,88 @@ Array> TrivialSubspaceDivision(const Array& iter return res; } -Stmt GenerateOuterInitBlock(const Array& inner_block_vars, Block block, - const std::vector& inner_loops, - const Array& inner_bindings, - const Array>& division) { - std::vector init_loops; - Stmt new_init; - std::vector init_block_vars; - std::vector init_block_vars_copy; - std::vector init_bindings; - std::unordered_map binding_replace_map; - std::unordered_map bv_replace_map; - std::unordered_map new_block_vars2old_index; - for (size_t i = 0; i < inner_block_vars.size(); ++i) { - if (inner_block_vars[i]->iter_type == IterVarType::kDataPar && +class SubspaceNotDivisibleError : public ScheduleError {}; + +/*! + * \brief Regenerate outer loops of a statement + * \param + */ +Stmt RegenerateLoops(const std::vector& loops, Stmt body) { + for (const ForNode* loop : loops) { + ObjectPtr new_loop = make_object(*loop); + new_loop->body = std::move(body); + body = For(new_loop); + } + return body; +} + +Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize, + const std::vector& inner_loops) { + Array init_block_iters; + Array init_bindings; + const Block& inner_block = inner_block_realize->block; + + // Step 1: Collect data-parallel block iters + for (size_t i = 0; i < inner_block->iter_vars.size(); i++) { + const IterVar& iter_var = inner_block->iter_vars[i]; + const PrimExpr& binding = inner_block_realize->iter_values[i]; + if (iter_var->iter_type == IterVarType::kDataPar && UsesVar(block->init.value(), - [v = inner_block_vars[i]->var](const VarNode* var) { return var == v.get(); })) { - // copy init block vars and ignore reduce block vars - init_block_vars.push_back(i); - IterVar init_block_var = inner_block_vars[i]; - init_block_var.CopyOnWrite()->var = inner_block_vars[i]->var.copy_with_suffix("_init"); - init_block_vars_copy.push_back(init_block_var); - bv_replace_map[inner_block_vars[i]->var] = init_block_var->var; - new_block_vars2old_index[init_block_var.get()] = i; + [&iter_var](const VarNode* var) { return var == iter_var->var.get(); })) { + init_block_iters.push_back(iter_var); + init_bindings.push_back(binding); } } + + // Step 2: Collect loops related to iters of the init block + std::vector init_loops; for (const ForNode* inner_loop : inner_loops) { - for (size_t i = 0; i < init_block_vars.size(); ++i) { - if (UsesVar(inner_bindings[new_block_vars2old_index[init_block_vars_copy[i].get()]], - [v = inner_loop->loop_var](const VarNode* var) { return var == v.get(); })) { - // copy loops related to init block vars - For init_loop = GetRef(inner_loop); - init_loop.CopyOnWrite()->loop_var = inner_loop->loop_var.copy_with_suffix(""); - // replace loop vars with copied loop vars - binding_replace_map[inner_loop->loop_var] = init_loop->loop_var; - init_loops.push_back(init_loop); - break; + for (const PrimExpr& init_binding : init_bindings) { + if (UsesVar(init_binding, + [inner_loop](const VarNode* var) { return var == inner_loop->loop_var.get(); })) { + init_loops.push_back(inner_loop); } } } - for (size_t i = 0; i < init_block_vars.size(); ++i) { - init_bindings.push_back(Substitute(inner_bindings[init_block_vars[i]], binding_replace_map)); - } - new_init = Substitute(Block(/*iter_vars=*/init_block_vars_copy, // - /*reads=*/{}, // - /*writes=*/block->writes, // - /*name_hint=*/block->name_hint + "_init", // - /*body=*/block->init.value(), // - /*init=*/NullOpt), - bv_replace_map); - new_init = BlockRealize(init_bindings, division.back()[1]->extent, Downcast(new_init)); - for (const auto& init_loop : init_loops) { - For new_init_loop = init_loop; - new_init_loop.CopyOnWrite()->body = new_init; - new_init = new_init_loop; + + // Step 3: Create new block iters for the init block + Map subst_map; + for (size_t i = 0; i < init_block_iters.size(); i++) { + IterVar new_iter_var = init_block_iters[i]; + auto* new_init_var_node = new_iter_var.CopyOnWrite(); + Var old_var = new_iter_var->var; + new_init_var_node->var = old_var.copy_with_suffix("_init"); + subst_map.Set(old_var, new_iter_var->var); + init_block_iters.Set(i, std::move(new_iter_var)); } - return new_init; -} -// TODO -class SubspaceNotDivisibleError : public ScheduleError {}; + // Step 4: Generate loop nests and the init block + Block init_block{/*iter_vars=*/init_block_iters, // + /*reads=*/{}, // + /*writes=*/block->writes, // + /*name_hint=*/block->name_hint + "_init", // + /*body=*/block->init.value(), // + /*init=*/NullOpt}; + Stmt new_init = BlockRealize( + /*iter_values=*/init_bindings, + /*predicate=*/inner_block_realize->predicate, + /*block=*/ std::move(init_block) + ); -std::array, Array>, 2> GenerateBlockIterVarBindings( - Block block, const Array>& division, - std::unordered_map& bv_iters) { - Array inner_block_vars, outer_block_vars; - Array inner_bindings, outer_bindings; - for (size_t i = 0; i < block->iter_vars.size(); ++i) { - const IterVar& iter_var = block->iter_vars[i]; - const arith::IterMapExprNode* outer_binding = - division[i][0]->source.as(); - const arith::IterMapExprNode* inner_binding = - division[i][1]->source.as(); - ICHECK(outer_binding); - ICHECK(inner_binding); - if (is_one(division[i][1]->extent)) { // IsOuter - // extract this iter var to outer block directly - outer_bindings.push_back( - arith::NormalizeIterMapToExpr(GetRef(outer_binding))); - outer_block_vars.push_back(iter_var); - // bv_iters[iter_var->var] = Range::FromMinExtent(0, division[i][0]->extent); - } else { - const IterVar outer_var(Range::FromMinExtent(0, division[i][0]->extent), - iter_var->var.copy_with_suffix("o"), iter_var->iter_type); - outer_bindings.push_back( - arith::NormalizeIterMapToExpr(GetRef(outer_binding))); - outer_block_vars.push_back(outer_var); - // generate a new iter var for outer block - PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; - if (const auto* op = division[i][1]->source.as()) { - base = base + op->base; - inner_bindings.push_back(base + - arith::NormalizeIterMapToExpr(arith::IterSumExpr(op->args, 0))); - } else { - inner_bindings.push_back( - base + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); - } - inner_block_vars.push_back(iter_var); - bv_iters[iter_var->var] = Range::FromMinExtent(base, division[i][1]->extent); - } + // Step 5: Generate the parent loops for the init block + for (const ForNode* init_loop : init_loops) { + ObjectPtr new_loop = make_object(*init_loop); + new_loop->loop_var = init_loop->loop_var.copy_with_suffix(""); + subst_map.Set(init_loop->loop_var, new_loop->loop_var); + new_loop->body = std::move(new_init); + new_init = For(new_loop); } - return {std::make_pair(outer_block_vars, outer_bindings), - std::make_pair(inner_block_vars, inner_bindings)}; + + // Step 6: Substitute with new loop variables and block iters to prevent duplication of + // variables in the outer block. + new_init = Substitute(new_init, subst_map); + + return new_init; } /*! @@ -240,11 +218,12 @@ class LoopSubspaceCollector { /*! * \brief Check the bindings of the block iters can be divided by a subspace collected by the - * collector. A ScheduleError is raised if the bindings are not divisible by the subspace. + * collector. * \param block_realize The block realize to be checked. * \param collector The collector which has collected the loops of the block. * \param analyzer The arithmetic analyzer. * \return The result of the subspace division. + * \throws ScheduleError If the bindings are not divisible by the subspace. */ Array> CheckSubspaceDivisible(const BlockRealize& block_realize, const LoopSubspaceCollector& collector, @@ -268,9 +247,126 @@ Array> CheckSubspaceDivisible(const BlockRealize& block_r return division; } -class BlockizeBlockBuilder { - +class BlockizedBindingExtractor { + public: + void ExtractBindings(const Array& iter_vars, + const Array>& division) { + ICHECK(iter_vars.size() + 1 == division.size()); + for (size_t i = 0; i < iter_vars.size(); ++i) { + const IterVar& iter_var = iter_vars[i]; + const arith::IterMapExprNode* outer_binding = + division[i][0]->source.as(); + const arith::IterMapExprNode* inner_binding = + division[i][1]->source.as(); + ICHECK(outer_binding); + ICHECK(inner_binding); + + // After computing the subspace division, bindings[i] can be written as + // outer_binding * inner_binding->extent + inner_binding + // The outer block will have binding: iter_outer -> outer_binding + // The inner block will have binding: iter_inner -> iter_outer * inner_binding->extent + + // inner_binding + + if (is_one(division[i][1]->extent)) { // IsOuter + // extract this iter var to outer block directly + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_iter_vars.push_back(iter_var); + } else { + const IterVar outer_var(Range::FromMinExtent(0, division[i][0]->extent), + iter_var->var.copy_with_suffix("o"), iter_var->iter_type); + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_iter_vars.push_back(outer_var); + // generate a new iter var for outer block + // TODO: add test case outer extent is zero + PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; + if (const auto* op = division[i][1]->source.as()) { + base = base + op->base; + inner_bindings.push_back(base + + arith::NormalizeIterMapToExpr(arith::IterSumExpr(op->args, 0))); + } else { + inner_bindings.push_back( + base + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); + } + inner_iter_vars.push_back(iter_var); + // bv_iter: inner block iter -> division inner extent + inner_iter_relaxed_range.Set(iter_var->var, + Range::FromMinExtent(base, division[i][1]->extent)); + } + } + } + /*! \brief Iters of the outer block. */ + Array outer_iter_vars; + /*! \brief Iters of the outer block. */ + Array inner_iter_vars; + /*! \brief Binding values of the outer block. */ + Array outer_bindings; + /*! \brief Binding values of the inner block. */ + Array inner_bindings; + + /*! \brief The range of the inner block iters Note that this is different from the domain of the + * inner block iters. */ + Map inner_iter_relaxed_range; +}; + + +/*! + * \brief + */ +BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, + const Map& inner_iter_relaxed_range, + arith::Analyzer* analyzer) { + Array new_region; + new_region.reserve(buffer_region->region.size()); + for (const auto& range : buffer_region->region) { + const Array& res = + arith::DetectIterMap({range->min}, inner_iter_relaxed_range, true, false, analyzer); + ICHECK_EQ(res.size(), 1); + const arith::IterSumExpr& normalized_expr = res[0]; + PrimExpr extent = 1; + if (normalized_expr->args.size() == 1) { + ICHECK(analyzer->CanProve(normalized_expr->args[0]->scale - range->extent == 0)); + extent = normalized_expr->args[0]->extent; + } + new_region.push_back(Range::FromMinExtent(normalized_expr->base, extent * range->extent)); + } + return BufferRegion(buffer_region->buffer, std::move(new_region)); }; + +BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor, + const Block& block, BlockRealize inner_block_realize, + const std::vector& inner_loops, + PrimExpr predicate, arith::Analyzer* analyzer) { + // Step 1: Generate the init block if needed + Optional new_init = NullOpt; + if (block->init.defined()) { + new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops); + } + + // Step 2: Compute the access regions of the outer block by relaxing the inner loops + Array new_reads = block->reads; + Array new_writes = block->writes; + + auto f_mutate = [&](const BufferRegion& buffer_region) { + return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_relaxed_range, analyzer); + }; + new_reads.MutateByApply(f_mutate); + new_writes.MutateByApply(f_mutate); + + Stmt outer_block_body = RegenerateLoops(inner_loops, inner_block_realize); + Block outer_block{/*iter_vars=*/extractor.outer_iter_vars, // + /*reads=*/new_reads, // + /*writes=*/new_writes, // + /*name_hint=*/"blockized_" + block->name_hint, // + /*body=*/std::move(outer_block_body), // + /*init=*/new_init}; + BlockRealize outer_block_realize{/*iter_values=*/extractor.outer_bindings, + /*predicate=*/std::move(predicate), + /*block=*/std::move(outer_block)}; + return outer_block_realize; +} + StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { /*! * Check: @@ -296,118 +392,37 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { Array> division = CheckSubspaceDivisible(block_realize, collector, &analyzer); - // Step 4: Generate a new inner block - // Step 4.1: Compute outer/inner block bindings - std::unordered_map - bv_iters; // iter_vars of the inner block - auto r = GenerateBlockIterVarBindings(block, division, bv_iters); - Array outer_block_vars = r[0].first; - Array outer_bindings = r[0].second; - - auto* inner_br = block_realize.CopyOnWrite(); - auto* inner_block = inner_br->block.CopyOnWrite(); - { - inner_br->iter_values = r[1].second; - inner_br->predicate = division.back()[1]->extent; - inner_block->iter_vars = r[1].first; - inner_block->init = NullOpt; - } + // Step 4: Generate bindings for the outer block and the inner block based on the result of + // the subspace division. + BlockizedBindingExtractor extractor; + extractor.ExtractBindings(block->iter_vars, division); + const PrimExpr& outer_pred = division.back()[0]->extent; + const PrimExpr& inner_pred = division.back()[1]->extent; - // Regenerate inner_loops - Stmt body = GetRef(inner_br); - for (const auto& inner_loop : collector.inner_loops) { - auto loop_node = make_object(*inner_loop); - loop_node->body = body; - body = For(loop_node); - } - // Regenerate init for outer block - Optional new_init = NullOpt; - if (block->init.defined()) { - new_init = GenerateOuterInitBlock(inner_block->iter_vars, block, collector.inner_loops, - inner_br->iter_values, division); - } + // Step 5: Generate the inner block. + BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite(); + BlockNode* inner_block = inner_block_realize->block.CopyOnWrite(); + inner_block_realize->iter_values = extractor.inner_bindings; + inner_block_realize->predicate = inner_pred; + inner_block->iter_vars = extractor.inner_iter_vars; + inner_block->init = NullOpt; - // Calculate outer block's IO region - auto rewrite_range = [&](const Range& range) -> Range { - const Array& res = - arith::DetectIterMap({range->min}, bv_iters, true, false, &analyzer); - ICHECK_EQ(res.size(), 1); - const arith::IterSumExpr& normalized_expr = res[0]; - PrimExpr extent = 1; - if (normalized_expr->args.size() == 1) { - CHECK(analyzer.CanProve(normalized_expr->args[0]->scale - range->extent == 0)); - extent = normalized_expr->args[0]->extent; - } - return Range::FromMinExtent(normalized_expr->base, extent * range->extent); - }; - std::vector reads, writes; - auto rewrite_region = [&](std::vector* regions, Array old_regions) { - for (auto buffer_region : old_regions) { - std::vector region; - for (const auto& range : buffer_region->region) { - region.push_back(rewrite_range(range)); - } - (*regions).emplace_back(buffer_region->buffer, region); - } - }; - rewrite_region(&reads, block->reads); - rewrite_region(&writes, block->writes); - // Generate a new outer block - auto outer_block = Block(/*iter_vars=*/outer_block_vars, // - /*reads=*/reads, // - /*writes=*/writes, // - /*name_hint=*/"blockized_" + block->name_hint, // - /*body=*/std::move(body), // - /*init=*/new_init); - auto outer_realize = BlockRealize(outer_bindings, division.back()[0]->extent, outer_block); - - // Step x: do actual replace + // Step 6: Generate the outer block. + BlockRealize outer_realize = + GenerateBlockizedOuterBlock(extractor, block, GetRef(inner_block_realize), + collector.inner_loops, outer_pred, &analyzer); + // Step 7: Do the actual replacement self->Replace(loop_sref, outer_realize, {{block, GetRef(inner_block)}}); - { - StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); - StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, - /*require_compact_dataflow*/ false); - // UpdateScope(self, scope_sref); - } - // RecalculateCachedFlags(self.operator->()); - - // } - // TODO(@wuwei): fix affine flags - // self->Replace(loop_sref, outer_realize, {{block, inner_block}}); - // { - // StmtSRef block_sref = self->stmt2ref.at(inner_block.get()); - // UpdateAffineFlag(self, block_sref); - // } - // { - // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); - // StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, - // /*require_compact_dataflow*/false); - // UpdateScope(self, scope_sref); - // UpdateAffineFlag(self, scope_sref); - // } - // { - // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); - // UpdateScope(self, block_sref); - // UpdateAffineFlag(self, block_sref); - // } - - // // Check loop binding - - // { - // struct BindingValidator : public StmtVisitor { - // void VisitStmt_(const BlockRealizeNode* realize) final { - // StmtSRef& sref = self->stmt2ref.at(realize->block.get()); - // UpdateAffineFlag(self, sref); - // VisitStmt(realize->block->body); - // } - // ScheduleState self; - // }; - // BindingValidator validator; - // validator.self = self; - // const PrimFuncNode* func = GetRootPrimFunc(self->mod, GetRootBlock(loop_sref).get(), - // nullptr); validator(func->body); - // } - return self->stmt2ref.at(outer_block.get()); + + // Step 8: Update the cached flags + const StmtSRef& outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); + BlockInfo& outer_block_info = self->block_info[outer_block_sref]; + const BlockInfo& inner_block_info = self->block_info.at(block_sref); + outer_block_info.affine_binding = inner_block_info.affine_binding; + outer_block_info.region_cover = inner_block_info.region_cover; + outer_block_info.scope->stage_pipeline = inner_block_info.scope->stage_pipeline; + + return outer_block_sref; } struct BlockizeTraits : public UnpackedInstTraits { @@ -430,6 +445,7 @@ struct BlockizeTraits : public UnpackedInstTraits { return py.Str(); } + template friend struct UnpackedInstTraits; }; diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index 8629ba2f4f9d..c1fa7efbf827 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -7,7 +7,30 @@ @T.prim_func -def elementwise(a: T.handle, c: T.handle) -> None: +def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + +@T.prim_func +def single_elementwise_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None: + with T.block("blockized_B"): + vio = T.axis.spatial(1, 0) + vjo = T.axis.spatial(1, 0) + T.reads(A[0 : 128, 0 : 128]) + T.writes(B[0 : 128, 0 : 128]) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + +@T.prim_func +def two_elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) B = T.alloc_buffer((128, 128)) @@ -152,7 +175,7 @@ def rowsum_blockized(a: T.handle, b: T.handle) -> None: def test_blockize(): func = elementwise # schedule - s = tir.Schedule(func, debug_mask="none") + s = tir.Schedule(func, debug_mask="all") B = s.get_block("B") _ = s.get_block("C") x, y = s.get_loops(B) @@ -161,13 +184,13 @@ def test_blockize(): s.reorder(xo, yo, xi, yi) s.blockize(xi) tvm.ir.assert_structural_equal(blockize, s.mod["main"]) - # verify_trace_roundtrip(sch=s, mod=func) + verify_trace_roundtrip(sch=s, mod=func) def test_blockize_schedule(): func = elementwise # test 1 - s = tir.Schedule(func, debug_mask="none") + s = tir.Schedule(func, debug_mask="all") B = s.get_block("B") C = s.get_block("C") x, y = s.get_loops(B) @@ -178,9 +201,9 @@ def test_blockize_schedule(): s.reverse_compute_at(C, yo) s.blockize(s.get_loops(C)[-2]) tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) - # verify_trace_roundtrip(sch=s, mod=func) + verify_trace_roundtrip(sch=s, mod=func) # test 2 - s = tir.Schedule(func, debug_mask="none") + s = tir.Schedule(func, debug_mask="all") B = s.get_block("B") C = s.get_block("C") x, y = s.get_loops(C) @@ -191,9 +214,9 @@ def test_blockize_schedule(): s.compute_at(B, yo) s.blockize(s.get_loops(B)[-2]) tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) - # verify_trace_roundtrip(sch=s, mod=func) + verify_trace_roundtrip(sch=s, mod=func) # test 3 - s = tir.Schedule(func, debug_mask="none") + s = tir.Schedule(func, debug_mask="all") B = s.get_block("B") C = s.get_block("C") x, y = s.get_loops(B) @@ -207,13 +230,14 @@ def test_blockize_schedule(): s.reorder(xCo, yCo, xCi, yCi) s.compute_at(b_outer, yCo) tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_2) - # verify_trace_roundtrip(sch=s, mod=func) + verify_trace_roundtrip(sch=s, mod=func) def test_blockize_init_loops(): - s = tir.Schedule(rowsum, debug_mask="none") + s = tir.Schedule(rowsum, debug_mask="all") k, _ = s.get_loops(s.get_block("B")) s.blockize(k) + print(s.mod['main'].script()) tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) # verify_trace_roundtrip(sch=s, mod=rowsum) From 292a75115ca52cd80b4749cc1ddcea3943e744a0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 4 Jan 2022 15:21:24 -0500 Subject: [PATCH 04/20] test cases --- include/tvm/tir/function.h | 50 ++ include/tvm/tir/schedule/schedule.h | 19 +- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/function.py | 27 ++ python/tvm/tir/schedule/schedule.py | 80 +++- src/arith/int_set.cc | 2 +- src/tir/ir/function.cc | 68 +++ src/tir/schedule/analysis.h | 68 +++ src/tir/schedule/analysis/analysis.cc | 328 +++++++++++++ src/tir/schedule/concrete_schedule.cc | 14 + src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 14 + .../schedule/primitive/blockize_tensorize.cc | 313 +++++++++--- src/tir/schedule/schedule.cc | 12 + src/tir/schedule/state.cc | 5 +- src/tir/schedule/traced_schedule.cc | 14 + src/tir/schedule/traced_schedule.h | 2 + src/tir/schedule/transform.cc | 47 ++ src/tir/schedule/transform.h | 8 + .../unittest/test_tir_schedule_blockize.py | 192 ++++---- .../unittest/test_tir_schedule_tensorize.py | 446 ++++++++++++++++++ 21 files changed, 1530 insertions(+), 183 deletions(-) create mode 100644 tests/python/unittest/test_tir_schedule_tensorize.py diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index e482a18c4a5b..3f59066ffa6b 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,56 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! + * \brief Tensor intrinsics for tensorization + */ +class TensorIntrinNode : public Object { + public: + /*! \brief The function to describe the computation. */ + PrimFunc description; + /*! \brief The intrinsic function for lower-level implementation. */ + PrimFunc implementation; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("description", &description); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "tir.TensorIntrin"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); +}; + +/*! + * \brief Managed reference to TensorIntrinNode. + */ +class TensorIntrin : public ObjectRef { + public: + /*! + * \brief Constructor + * \param desc_func The function to describe the computation. + * \param intrin_func The intrinsic function for lower-level implementation. + */ + TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func); + + /*! + * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked + * up with its name. + * \param name The name of the TensorIntrin to register + * \param desc_func The function to describe the computation. + * \param intrin_func The intrinsic function for lower-level implementation. + * \return The created TensorIntrin. + */ + TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, PrimFunc intrin_func); + + /*! + * \brief Look up TensorIntrin by name. Raises an exception if not found. + * \param name The name of the TensorIntrin. + */ + TVM_DLL static TensorIntrin Get(String name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) +}; + /*! * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 411b20d6a855..58d57222b439 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -474,11 +474,24 @@ class ScheduleNode : public runtime::Object { virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; /******** Schedule: Blockize & Tensorize ********/ /*! - * \brief Convert the subtree rooted by a specific loop into a block. - * \param loop_rv The root of the subtree - * \return The new block + * \brief Convert the subtree rooted at a specific loop into a block. + * \param loop_rv the root of the subtree + * \return the new block */ virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with tensor_intrin + * \param loop_rv the loop/block to be tensorized + * \param intrin the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with tensor_intrin + * \param loop_rv The loop/block to be tensorized + * \param intrin_name Name of the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const String& intrin_name) = 0; + /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 07ceb29ebf98..5854b9369c16 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize -from .function import PrimFunc +from .function import PrimFunc, TensorIntrin from .op import call_packed, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index ecbcd837cb72..ed693a9f3c49 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -162,3 +162,30 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: return tvm._ffi.get_global_func("script.AsTVMScript")( self, tir_prefix, show_meta ) # type: ignore + + +@tvm._ffi.register_object("tir.TensorIntrin") +class TensorIntrin(Object): + """A tensor intrinsic. + + Parameters + ---------- + desc_func: PrimFunc + The function to describe the computation + + intrin_func: PrimFunc + The function for execution + """ + + def __init__(self, desc_func, intrin_func): + self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc_func, intrin_func) + + @staticmethod + def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc): + return _ffi_api.TensorIntrinRegister( # type: ignore + name, desc_func, intrin_func + ) + + @staticmethod + def get(name: str): + return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 14950ac5fefb..f0d33f2e7c26 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -21,7 +21,7 @@ from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc +from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, TensorIntrin from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod @@ -1761,8 +1761,86 @@ def after_set_scope( @type_checked def blockize(self, loop: LoopRV) -> BlockRV: + """Convert the subtree rooted at a specific loop into a block. + + Parameters + ---------- + loop : LoopRV + The root of the subtree. + + Returns + ------- + result : BlockRV + The new block. + + Examples + -------- + + Before blockize, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_blockize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + ) -> None: + for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): + with T.block("B"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + Create the schedule and do set_scope: + + .. code-block:: python + + sch = tir.Schedule(before_blockize) + B = sch.get_block("B") + _, _, i1, _ = sch.get_loops(B) + sch.blockize(i1) + print(sch.mod["main"].script()) + + After applying blockize, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_blockize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + )-> None: + for i_0, j_0 in T.grid(8, 8): + with T.block("blockized_B"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("B"): + vi = T.axis.spatial(128, vio * 16 + i_1) + vj = T.axis.spatial(128, vjo * 16 + j_1) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + Note + ---- + blockize requires there is exactly one block under the given loop and the bindings of the + block are divisible by the subspace represented by the loops starting at the given loop. + + """ + return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member + @type_checked + def tensorize(self, loop: LoopRV, tensor_intrin: Union[str, TensorIntrin]) -> None: + if isinstance(tensor_intrin, str): + tensor_intrin = String(tensor_intrin) + _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member + self, loop, tensor_intrin) + ########## Schedule: Annotation ########## @type_checked diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 55a1a5a1830e..3d30eef99d7d 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -511,7 +511,7 @@ Range IntSet::CoverRange(Range max_range) const { const IntervalSetNode* s_int = (*this).as(); ICHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { - return Range::FromMinExtent(s_int->min_value, + return Range::FromMinExtent(analyzer.Simplify(s_int->min_value), analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 101d80a52ea1..8d99c864fa49 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -64,6 +64,65 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); +class TensorIntrinManager { + public: + Map reg; + + static TensorIntrinManager* Global() { + static TensorIntrinManager* inst = new TensorIntrinManager(); + return inst; + } +}; + +TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) { + // check the number of func var is equal + CHECK_EQ(desc_func->params.size(), intrin_func->params.size()); + CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size()); + + // check both functions' bodies are directly block + const auto* desc_realize = Downcast(desc_func->body)->block->body.as(); + const auto* intrin_realize = Downcast(intrin_func->body)->block->body.as(); + CHECK(desc_realize != nullptr) << "description function's body expect a directly block"; + CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block"; + + const Block& desc_block = desc_realize->block; + const Block& intrin_block = intrin_realize->block; + + // check block var number and iter type + CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size()) + << "Two blocks should have the same number of block vars"; + for (size_t i = 0; i < desc_block->iter_vars.size(); i++) { + const IterVar& desc_var = desc_block->iter_vars[i]; + const IterVar& intrin_var = intrin_block->iter_vars[i]; + CHECK(desc_var->iter_type == intrin_var->iter_type) + << "Block iter_type mismatch between " << desc_var->iter_type << " and " + << intrin_var->iter_type; + } + + auto n = make_object(); + n->description = std::move(desc_func); + n->implementation = std::move(intrin_func); + data_ = std::move(n); +} + +TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc intrin_func) { + TensorIntrinManager* manager = TensorIntrinManager::Global(); + ICHECK_EQ(manager->reg.count(name), 0) + << "ValueError: TensorIntrin '" << name << "' has already been registered"; + TensorIntrin intrin(desc_func, intrin_func); + manager->reg.Set(name, intrin); + return intrin; +} + +TensorIntrin TensorIntrin::Get(String name) { + const TensorIntrinManager* manager = TensorIntrinManager::Global(); + ICHECK_EQ(manager->reg.count(name), 1) + << "ValueError: TensorIntrin '" << name << "' is not registered"; + return manager->reg.at(name); +} + +TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { // TODO(tvm-team) redirect to Text printer once we have a good text format. @@ -85,5 +144,14 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc") return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); + +TVM_REGISTER_GLOBAL("tir.TensorIntrin") + .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { + return TensorIntrin(desc_func, intrin_func); + }); + +TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); +TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 636cc7d0a5db..2971493363db 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -558,6 +558,74 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // int64_t max_parallel_extent, // int64_t max_parallel_basic); +/******** Tensorization Related ********/ + +using ExprComparator = ExprFunctor; +using StmtComparator = StmtFunctor; + +/* \brief Deep comparison to check if two IR ASTs are equivalent */ +class TensorizeComparator : public ExprComparator, public StmtComparator { + public: + explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {} + + // Map from rhs buffer to lhs buffer + std::unordered_map rhs_buffer_map_; + // Buffer indices mapping + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; + std::vector extra_block_vars_; + // variable remap if any + std::unordered_map equal_map_; + + bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; + bool VisitStmt(const Stmt& n, const Stmt& other) override; + + bool VisitStmt_(const ForNode* op, const Stmt& other) override; + bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + + bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; + bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const ModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; + bool VisitExpr_(const NENode* op, const PrimExpr& other) override; + bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const LENode* op, const PrimExpr& other) override; + bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const GENode* op, const PrimExpr& other) override; + bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; + bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; + bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + + bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); + virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const Map& lhs, const Map& rhs); + template + bool CompareBufferAccess(const T* lhs, const T* rhs); + template + bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + bool CompareRange(const Range& lhs, const Range& rhs); + bool CompareType(const DataType& lhs, const DataType& rhs); + + protected: + bool assert_mode_; + bool is_scope_block = true, is_inner_block = true; +}; + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index be5e55d4ec70..918caa5091a7 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1659,6 +1659,334 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } catch (...) { throw InvalidStorageScopeError(self->mod, std::move(storage_scope)); } +}; + +/******** Tensorize Comparator ********/ + +bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { + if (n.same_as(other)) return true; + if (n->type_index() != other->type_index()) return false; + bool equal = StmtComparator::VisitStmt(n, other); + if (!equal && assert_mode_) + LOG(FATAL) << "Stmts are not matching between:\n" << n << "\nand\n" << other; + return equal; +} + +bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + if (!VisitExpr(op->min, rhs->min)) return false; + if (!VisitExpr(op->extent, rhs->extent)) return false; + if (!VisitStmt(op->body, rhs->body)) return false; + if (op->kind != rhs->kind) return false; + if (op->thread_binding.defined() ^ rhs->thread_binding.defined()) return false; + if (op->thread_binding.defined() && + !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) + return false; + return CompareAnnotationMap(op->annotations, rhs->annotations); +} + +bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); +} + +bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Skip Compare binding values if the block is scope block (the outermost one). + if (!is_scope_block) { + size_t offset = op->iter_values.size() - rhs->iter_values.size(); + if (rhs->iter_values.size() > op->iter_values.size()) return false; + if (is_inner_block) { + // weak pattern matching for the inner block (the son of the scope block) + // where the pattern is v + iter <=> expr + iter + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + PrimExpr lhs_expr, rhs_expr; + Optional lhs_iter, rhs_iter; + auto detect = [](const PrimExpr& binding) -> std::pair> { + arith::PVar expr; + arith::PVar iter; + if (iter.Match(binding)) { + return std::make_pair(0, iter.Eval()); + } else if ((expr + iter).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else if ((iter + expr).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else { + return std::make_pair(expr.Eval(), NullOpt); + } + }; + std::tie(lhs_expr, lhs_iter) = detect(op->iter_values[i + offset]); + std::tie(rhs_expr, rhs_iter) = detect(rhs->iter_values[i]); + CHECK((lhs_iter && rhs_iter) || (!lhs_iter && !rhs_iter)) << "Incompatible binding"; + if (lhs_iter) VisitExpr(lhs_iter.value(), rhs_iter.value()); + if (is_zero(rhs_expr)) { + CHECK(is_zero(lhs_expr)) << "Incompatible binding"; + } else { + const auto* bv = rhs_expr.as(); + if (!bv) { + VisitExpr(lhs_expr, rhs_expr); + } else { + auto it = equal_map_.find(GetRef(bv)); + if (it == equal_map_.end()) { + equal_map_[GetRef(bv)] = lhs_expr; + } else { + CHECK(it->second->IsInstance()); + VisitExpr(lhs_expr, Downcast(it->second)); + } + } + } + } + } else { + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + if (!VisitExpr(op->iter_values[i + offset], rhs->iter_values[i])) return false; + } + const Block& block = op->block; + for (size_t i = 0; i < offset; ++i) { + Var block_var = Downcast(op->iter_values[i]); + auto it = equal_map_.find(block_var); + equal_map_[block->iter_vars[i]->var] = (it == equal_map_.end() ? block_var : it->second); + } + } + } + + return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); +} + +bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Check block equality. + // All iter vars and buffer regions including the order shoudl match. + // When checking iter vars, DefEqual is used to remap variables. + // Only the inner most several axis are compared. Other iter vars are added to extra_block_vars. + if (op->iter_vars.size() < rhs->iter_vars.size()) return false; + + size_t offset = op->iter_vars.size() - rhs->iter_vars.size(); + for (size_t i = 0; i < rhs->iter_vars.size(); ++i) { + auto lhs_var = op->iter_vars[i + offset], rhs_var = rhs->iter_vars[i]; + // Skip iter dom + if (!DefEqual(lhs_var->var, rhs_var->var)) { + return false; + } + if (lhs_var->iter_type != rhs_var->iter_type) { + return false; + } + } + + if (is_scope_block) { + for (size_t i = 0; i < offset; ++i) { + extra_block_vars_.push_back(op->iter_vars[i]); + } + } + + if (!is_scope_block) { + if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + return false; + } + if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + return false; + } + } + is_scope_block = false; + return VisitStmt(op->body, rhs->body); +} + +// Exprs +#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ + bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ + } + +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); + +bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + auto lhs = GetRef(op); + if (lhs.same_as(other)) return true; + if (!CompareType(op->dtype, rhs->dtype)) return false; + auto it = equal_map_.find(lhs); + return it != equal_map_.end() && it->second.same_as(other); +} + +bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); +} + +bool TensorizeComparator::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs->type_index() != rhs->type_index()) return false; + auto it = equal_map_.find(lhs); + // If there is already a mapping + if (it != equal_map_.end()) return it->second.same_as(rhs); + equal_map_[lhs] = rhs; + return true; +} + +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first != rhs.first) return false; + if (!lhs.second.same_as(rhs.second)) return false; + return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); +} + +bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, + const Map& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + + auto sort_map = + [](const Map& map) -> std::vector> { + std::vector> ret; + ret.reserve(map.size()); + for (const auto& pair : map) { + ret.emplace_back(pair); + } + sort(ret.begin(), ret.end()); + return ret; + }; + + auto lhs_array = sort_map(lhs), rhs_array = sort_map(rhs); + + for (size_t i = 0; i < lhs.size(); ++i) { + if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + // Remap both buffer itself and buffer data + // Skip buffer shape + bool equal = DefEqual(lhs, rhs) && DefEqual(lhs->data, rhs->data) && + CompareType(lhs->dtype, rhs->dtype) && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } else if (assert_mode_) { + LOG(FATAL) << "Buffers are not matching between:" << lhs << " and " << rhs; + } + return equal; +} + +bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { + // Only for block region declaration + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + // Number of indices in desc_block must be smaller than it in AST + if (rhs->region.size() > lhs->region.size()) return false; + + std::vector lhs_region; + for (const auto& range : lhs->region) { + lhs_region.push_back(Range::FromMinExtent(range->min, range->extent)); + } + size_t offset = lhs_region.size() - rhs->region.size(); + // initialize buffer indices + bool need_update = false; + if (!buffer_indices_.count(lhs->buffer)) { + need_update = true; + buffer_indices_[lhs->buffer] = std::vector(); + } else { + if (offset != buffer_indices_[lhs->buffer].size()) return false; + } + std::vector& indices = buffer_indices_[lhs->buffer]; + for (size_t i = 0; i < offset; ++i) { + const Range& range = lhs_region[i]; + // High-dim region must be element-wise + if (!is_one(range->extent)) return false; + if (need_update) { + indices.push_back(range->min); + } else { + // The order matters since we only map inner block_var to outside block_var + if (!VisitExpr(range->min, indices[i])) return false; + } + } + for (size_t i = 0; i < rhs->region.size(); ++i) { + if (!CompareRange(lhs_region[i + offset], rhs->region[i])) return false; + } + return true; +} + +// Only for BufferStoreNode and BufferLoadNode +template +bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + + if (rhs->indices.size() > lhs->indices.size()) return false; + // otherwise + size_t offset = lhs->indices.size() - rhs->indices.size(); + for (size_t i = 0; i < rhs->indices.size(); ++i) { + if (!VisitExpr(lhs->indices[i + offset], rhs->indices[i])) return false; + } + return true; +} + +template +bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(this->*cmp)(lhs[i], rhs[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { + return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); +} + +bool TensorizeComparator::CompareType(const DataType& lhs, const DataType& rhs) { + if (lhs == rhs) return true; + return lhs.code() == rhs.code() && lhs.bits() == rhs.bits() && lhs.lanes() == rhs.lanes(); +} + +// Deep comparison to check if two IR graph are equivalent +bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { + bool equal = (n->type_index() == other->type_index()) && ExprComparator::VisitExpr(n, other); + if (!equal && assert_mode_) + LOG(FATAL) << "Exprs are not matching between:" << n << " and " << other; + return equal; } bool IsSpatial(const StmtSRef& block_sref) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 7e694a938550..a8f815fec172 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -615,6 +615,20 @@ BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) { return CreateRV(result); } +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(loop_rv), intrin); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); +} + /******** Schedule: Annotation ********/ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index e95a432985b6..dde811db4d1a 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -124,6 +124,8 @@ class ConcreteScheduleNode : public ScheduleNode { void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv) override; + void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) override; + void Tensorize(const LoopRV& loop_rv, const String& intrin_name) override; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 09df62e5025b..1f3032d10b2a 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -379,8 +379,22 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer /******** Schedule: Blockize & Tensorize ********/ +/*! + * \brief Convert the subtree rooted at a specific loop into a block. + * \param self The state of the schedule + * \param loop_sref The root of the subtree + * \return The new block + */ TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); +/*! + * \brief Tensorize the computation enclosed by loop with tensor_intrin. + * \param self The state of the schedule + * \param loop_sref The loop to be tensorized. + * \param intrin The tensor intrinsic. + */ +TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin& intrin); + /******** Schedule: Annotation ********/ /*! * \brief Annotate a block/loop with a key value pair diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 37a8fe522546..29248ce171dc 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -23,6 +23,36 @@ namespace tvm { namespace tir { +/*! + * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace + * represented by the outer loops. + */ +class SubspaceNotDivisibleError : public ScheduleError { + public: + explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block) + : mod_(std::move(mod)), + scope_loop_(std::move(scope_loop)), + inner_block_(std::move(inner_block)) {} + + String FastErrorString() const final { + return "ScheduleError: The bindings of the inner block can not be blockized."; + } + + String DetailRenderTemplate() const final { + return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops " + "starting at {1}."; + } + + IRModule mod() const final { return mod_; } + + Array LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } + + private: + IRModule mod_; + For scope_loop_; + Block inner_block_; +}; + /*! * \brief Detect if bindings are a trivial case of the subspace division where we can divide the * block iter bindings into two categories: @@ -53,6 +83,8 @@ Array> TrivialSubspaceDivision(const Array& iter for (const Var& var : inner_iters) { inner_loop_vars.insert(var.get()); } + const arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); + for (size_t i = 0; i < bindings.size(); ++i) { bool outer = UsesVar( bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); }); @@ -69,16 +101,16 @@ Array> TrivialSubspaceDivision(const Array& iter if (outer && !inner) { arith::IterMark outer{nullptr}; const auto& outer_iter = iter_mark; - arith::IterMark inner_iter(arith::IterSumExpr({}, 0), 1); - res.push_back(Array({outer_iter, inner_iter})); + const auto& inner_iter = unit_iter_mark; + res.push_back({outer_iter, inner_iter}); } else if (inner && !outer) { + const auto& outer_iter = unit_iter_mark; const auto& inner_iter = iter_mark; - arith::IterMark outer_iter(arith::IterSumExpr({}, 0), 1); - res.push_back(Array({outer_iter, inner_iter})); + res.push_back({outer_iter, inner_iter}); } else if (!outer && !inner) { - arith::IterMark outer_iter(arith::IterSumExpr({}, 0), 1); - arith::IterMark inner_iter(arith::IterSumExpr({}, 0), 1); - res.push_back(Array({outer_iter, inner_iter})); + const auto& outer_iter = unit_iter_mark; + const auto& inner_iter = unit_iter_mark; + res.push_back({outer_iter, inner_iter}); } else { return {}; } @@ -88,21 +120,13 @@ Array> TrivialSubspaceDivision(const Array& iter return res; } -class SubspaceNotDivisibleError : public ScheduleError {}; - /*! - * \brief Regenerate outer loops of a statement - * \param + * \brief Generate the blockized init block. + * \param block The original block with init. + * \param inner_block_realize The block realize of the inner block after blockize. + * \param inner_loops The inner loops after blockize. + * \return The subtree of the init block and its outer loops. */ -Stmt RegenerateLoops(const std::vector& loops, Stmt body) { - for (const ForNode* loop : loops) { - ObjectPtr new_loop = make_object(*loop); - new_loop->body = std::move(body); - body = For(new_loop); - } - return body; -} - Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize, const std::vector& inner_loops) { Array init_block_iters; @@ -151,10 +175,9 @@ Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_r /*body=*/block->init.value(), // /*init=*/NullOpt}; Stmt new_init = BlockRealize( - /*iter_values=*/init_bindings, - /*predicate=*/inner_block_realize->predicate, - /*block=*/ std::move(init_block) - ); + /*iter_values=*/init_bindings, + /*predicate=*/inner_block_realize->predicate, + /*block=*/std::move(init_block)); // Step 5: Generate the parent loops for the init block for (const ForNode* init_loop : init_loops) { @@ -219,21 +242,24 @@ class LoopSubspaceCollector { /*! * \brief Check the bindings of the block iters can be divided by a subspace collected by the * collector. + * \param mod The current IR module. * \param block_realize The block realize to be checked. * \param collector The collector which has collected the loops of the block. * \param analyzer The arithmetic analyzer. * \return The result of the subspace division. * \throws ScheduleError If the bindings are not divisible by the subspace. */ -Array> CheckSubspaceDivisible(const BlockRealize& block_realize, +Array> CheckSubspaceDivisible(const IRModule& mod, + const BlockRealize& block_realize, const LoopSubspaceCollector& collector, arith::Analyzer* analyzer) { const Block& block = block_realize->block; + DiagnosticContext diag_ctx(DiagnosticContext::Default(mod)); Array> division = arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars, block_realize->predicate, - /*require_bijective=*/false, analyzer); + /*require_bijective=*/false, analyzer, diag_ctx); if (division.empty()) { // If we can't do perfect subspace division, check if it is a trivial case of subspace division. @@ -242,13 +268,23 @@ Array> CheckSubspaceDivisible(const BlockRealize& block_r collector.outer_loop_vars, collector.inner_loop_vars, block_realize->predicate); } - // TODO: raise schedule error - CHECK(!division.empty()) << "ValueError: The bindings of the block below can not be blockized"; + if (division.empty()) { + throw SubspaceNotDivisibleError(mod, GetRef(collector.inner_loops.back()), block); + } return division; } +/*! + * \brief The binding extractor to compute the bindings of the outer and the inner blocks after + * blockize. + */ class BlockizedBindingExtractor { public: + /*! + * \brief Extract bindings for blockize. + * \param iter_vars The iter vars of the original inner block. + * \param division The result of the subspace division. + */ void ExtractBindings(const Array& iter_vars, const Array>& division) { ICHECK(iter_vars.size() + 1 == division.size()); @@ -278,8 +314,6 @@ class BlockizedBindingExtractor { outer_bindings.push_back( arith::NormalizeIterMapToExpr(GetRef(outer_binding))); outer_iter_vars.push_back(outer_var); - // generate a new iter var for outer block - // TODO: add test case outer extent is zero PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; if (const auto* op = division[i][1]->source.as()) { base = base + op->base; @@ -290,9 +324,8 @@ class BlockizedBindingExtractor { base + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); } inner_iter_vars.push_back(iter_var); - // bv_iter: inner block iter -> division inner extent inner_iter_relaxed_range.Set(iter_var->var, - Range::FromMinExtent(base, division[i][1]->extent)); + arith::IntSet::FromMinExtent(base, division[i][1]->extent)); } } } @@ -306,34 +339,43 @@ class BlockizedBindingExtractor { Array inner_bindings; /*! \brief The range of the inner block iters Note that this is different from the domain of the - * inner block iters. */ - Map inner_iter_relaxed_range; + * inner block iters. + */ + Map inner_iter_relaxed_range; }; - /*! - * \brief + * \brief Compute the access region of the outer block by relaxing the inner loops. + * \param buffer_region The original buffer region. + * \param The range of the inner loops. + * \param analyzer The arithmetic analyzer. + * \return The new buffer region. */ BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, - const Map& inner_iter_relaxed_range, + const Map& inner_iter_relaxed_range, arith::Analyzer* analyzer) { Array new_region; new_region.reserve(buffer_region->region.size()); - for (const auto& range : buffer_region->region) { - const Array& res = - arith::DetectIterMap({range->min}, inner_iter_relaxed_range, true, false, analyzer); - ICHECK_EQ(res.size(), 1); - const arith::IterSumExpr& normalized_expr = res[0]; - PrimExpr extent = 1; - if (normalized_expr->args.size() == 1) { - ICHECK(analyzer->CanProve(normalized_expr->args[0]->scale - range->extent == 0)); - extent = normalized_expr->args[0]->extent; - } - new_region.push_back(Range::FromMinExtent(normalized_expr->base, extent * range->extent)); + Array relaxed_int_set = + arith::EvalSet(buffer_region->region, inner_iter_relaxed_range); + ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size()); + for (size_t i = 0; i < buffer_region->region.size(); i++) { + Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]); + new_region.push_back(relaxed_int_set[i].CoverRange(max_range)); } return BufferRegion(buffer_region->buffer, std::move(new_region)); }; +/*! + * \brief Generate the outer block after blockize. + * \param extractor The binding extractor which has extracted the blockized bindings. + * \param block The original inner block. + * \param inner_block_realize The block realize of the inner block after blockize. + * \param inner_loops The inner loops after blockize. + * \param predicate The outer predicate of the subspace division. + * \param analyzer The arithmetic analyzer. + * \return The block realize of the outer block after blockize. + */ BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor, const Block& block, BlockRealize inner_block_realize, const std::vector& inner_loops, @@ -354,7 +396,16 @@ BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extrac new_reads.MutateByApply(f_mutate); new_writes.MutateByApply(f_mutate); - Stmt outer_block_body = RegenerateLoops(inner_loops, inner_block_realize); + // Step 3: Generate the body of the outer block. The body of the outer block is the inner block + // realize and its surounding loops. + Stmt outer_block_body = inner_block_realize; + for (const ForNode* loop : inner_loops) { + ObjectPtr new_loop = make_object(*loop); + new_loop->body = std::move(outer_block_body); + outer_block_body = For(new_loop); + } + + // Step 4: Generate the outer block and block realize. Block outer_block{/*iter_vars=*/extractor.outer_iter_vars, // /*reads=*/new_reads, // /*writes=*/new_writes, // @@ -368,14 +419,6 @@ BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extrac } StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { - /*! - * Check: - * - The sub AST is one-line with only one block - * - * Mutate: - * - extra block var from the only block - * - Update block binding - */ const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); arith::Analyzer analyzer; @@ -390,7 +433,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { // Step 3: Calculate subspace division for the inner loops. Array> division = - CheckSubspaceDivisible(block_realize, collector, &analyzer); + CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer); // Step 4: Generate bindings for the outer block and the inner block based on the result of // the subspace division. @@ -416,15 +459,130 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { // Step 8: Update the cached flags const StmtSRef& outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); - BlockInfo& outer_block_info = self->block_info[outer_block_sref]; - const BlockInfo& inner_block_info = self->block_info.at(block_sref); - outer_block_info.affine_binding = inner_block_info.affine_binding; - outer_block_info.region_cover = inner_block_info.region_cover; - outer_block_info.scope->stage_pipeline = inner_block_info.scope->stage_pipeline; - + StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false, + /*require_subtree_compact_dataflow=*/false); + BlockInfo old_block_info = self->GetBlockInfo(scope_root); + self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + // 'affine_binding' depends on the outer loops and are not changed. + self->block_info[scope_root].affine_binding = old_block_info.affine_binding; return outer_block_sref; } +/*! + * \brief Update the map from the buffers in the description to the implementation of the tensor + * intrinsic. \param intrinsic The tensor intrinsic. \param buffer_map The map to be updated. + */ +void RemapTensorIntrinBuffers( + const TensorIntrin& intrinsic, + std::unordered_map* buffer_map) { + ICHECK_EQ(intrinsic->description->params.size(), intrinsic->implementation->params.size()); + for (size_t i = 0; i < intrinsic->description->params.size(); ++i) { + const auto& lhs_var = intrinsic->description->params[i]; + const auto& lhs_buffer = intrinsic->description->buffer_map[lhs_var]; + const auto& rhs_var = intrinsic->implementation->params[i]; + const auto& rhs_buffer = intrinsic->implementation->buffer_map[rhs_var]; + (*buffer_map)[rhs_buffer] = lhs_buffer; + } +} + +void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin& intrinsic) { + /*! + * Check: + * - Check buffer binding, including type, alignment, shape and etc. + * - Check the sub AST is equal to the description function. + * + * Mutate: + * - Blockize the sub AST (please refer blockize for details) + * - Bind buffers + * - Mutate the implementation of the tensor intrinsic by replacing its buffers with new + * buffers created via match buffer region. + * - Replace the sub tree with the mutated function. + */ + const auto* loop = loop_sref->StmtAs(); + CHECK(loop) << "Only support tensorize a loop for now"; + + const auto* desc_block_realize = + Downcast(intrinsic->description->body)->block->body.as(); + const Block& desc_block = desc_block_realize->block; + const auto* impl_block_realize = + Downcast(intrinsic->implementation->body)->block->body.as(); + Block impl_block = impl_block_realize->block; + + // Step 1: Blockize the subtree rooted at the given loop + const StmtSRef& block_sref = Blockize(self, loop_sref); + const BlockRealize& block_realize = GetBlockRealize(self, block_sref); + + // Step 2: Compare the block with the description of the tensor intrinsic, find the correspondence + // between buffers in the block and the description. + TensorizeComparator comparator(/*assert_mode=*/true); + comparator.VisitStmt(block_realize, GetRef(desc_block_realize)); + + // Step 3: Find the correspondence between buffers in the current AST and the implementation of + // the tensor intrinsic + + // Step 3.1: Map from intrinsic func buffer to description func buffer + std::unordered_map intrin_buffer_map; + RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map); + // Step 3.2: Map form intrinsic func buffer to current AST buffer + std::unordered_map buffer_map; + for (const auto& pair : intrin_buffer_map) { + auto it = comparator.rhs_buffer_map_.find(pair.second); + ICHECK(it != comparator.rhs_buffer_map_.end()); + buffer_map[pair.first] = it->second; + } + + // Step 4: Create MatchBufferRegion for the params of the implementation function of the tensor + // intrin to make them subregions of the buffer in the original IR. + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_region_map; + for (const auto& read : impl_block->reads) { + buffer_region_map.emplace(read->buffer, read->region); + } + for (const auto& write : impl_block->writes) { + buffer_region_map.emplace(write->buffer, write->region); + } + Array match_buffer_regions; + for (size_t i = 0; i < intrinsic->implementation->params.size(); ++i) { + const auto& param = intrinsic->implementation->params[i]; + const auto& buffer = intrinsic->implementation->buffer_map.at(param); + const auto& source = buffer_map.at(buffer); + Region region = buffer_region_map.at(buffer); + auto extra_indices = comparator.buffer_indices_.at(source); + std::vector extra_buffer_ranges; + std::transform(extra_indices.begin(), extra_indices.end(), + std::back_inserter(extra_buffer_ranges), + [](const PrimExpr& index) { return Range::FromMinExtent(index, 1); }); + region.insert(region.begin(), extra_buffer_ranges.begin(), extra_buffer_ranges.end()); + match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, region))); + } + + // Step 5: Replace the subtree in the original IR with the tensor intrin implementation. + ObjectPtr new_block_ptr = make_object(*block_realize->block.get()); + new_block_ptr->body = impl_block->body; + ICHECK(new_block_ptr->match_buffers.empty()); + new_block_ptr->match_buffers = std::move(match_buffer_regions); + Block new_block(new_block_ptr); + + // Substitution map for the intrin block to get the correct bindings. + Map bv_map; + for (size_t i = 0; i < desc_block->iter_vars.size(); ++i) { + auto it = comparator.equal_map_.find(desc_block->iter_vars[i]->var); + if (it != comparator.equal_map_.end()) { + bv_map.Set(impl_block->iter_vars[i]->var, Downcast(it->second)); + } else { + bv_map.Set(impl_block->iter_vars[i]->var, Integer(0)); + } + } + new_block = Downcast(SubstituteInScope(new_block, bv_map)); + self->Replace(block_sref, new_block, {{block_realize->block, new_block}}); + + // Step 6: Update the cached flags. + StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_subtree_compact_dataflow=*/false); + self->UpdateScopeBlockInfo(static_cast(scope_root->stmt)->body); +} + +/******** InstructionKind Registration ********/ + struct BlockizeTraits : public UnpackedInstTraits { static constexpr const char* kName = "Blockize"; static constexpr bool kIsPure = false; @@ -446,10 +604,35 @@ struct BlockizeTraits : public UnpackedInstTraits { } template - friend struct UnpackedInstTraits; + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct TensorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Tensorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String intrin_name) { + return sch->Tensorize(loop_rv, intrin_name); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, String intrin_name) { + PythonAPICall py("tensorize"); + py.Input("loop", loop_rv); + py.Input("intrin", intrin_name); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index f9a7c4bab614..2736375dc591 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -187,6 +187,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") /******** (FFI) Blockize & Tensorize ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_method(&ScheduleNode::Blockize); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") + .set_body_typed([](Schedule self, LoopRV loop_rv, ObjectRef intrin) { + if (const auto* str = intrin.as()) { + return self->Tensorize(loop_rv, GetRef(str)); + } + if (const auto* p_intrin = intrin.as()) { + return self->Tensorize(loop_rv, GetRef(p_intrin)); + } + LOG(FATAL) << "TypeError: Cannot handle type: " << intrin->GetTypeKey(); + throw; + }); + /******** (FFI) Annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 04b7dd5ea2af..b4efbcb960d1 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -201,9 +201,8 @@ class BlockInfoCollector : private StmtVisitor { bool is_root_block = srefs_.empty(); // Calculate `BlockInfo::scope` Array child_block_srefs = std::move(block_frames_.back()); - BlockInfo& info = - self_->block_info.emplace(scope_root, BlockInfo(BlockScope(child_block_srefs))) - .first->second; + self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); + BlockInfo& info = self_->block_info[scope_root]; // Set `affine_binding` if (is_root_block) { // If the block doesn't have outer loops and BlockRealize, diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index fc356cf50484..f699963f7e11 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -367,6 +367,20 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { return new_block; } +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { + ConcreteScheduleNode::Tensorize(loop_rv, intrin_name); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{intrin_name}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& tensor_intrin) { + LOG(FATAL) << "TensorIntrin cannot be directly passed to meta schedule. Please register the tensor intrin and pass the intrin name instead."; +} + /******** Schedule: Annotation ********/ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 9940749d2f3e..1cba79a1ee47 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -88,6 +88,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv) final; + void Tensorize(const LoopRV& loop_rv, const TensorIntrin& tensor_intrin) final; + void Tensorize(const LoopRV& loop_rv, const String& intrin_name) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index ffb6b2d52628..0165df5949df 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -136,5 +136,52 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } +/******** IR Substitution ********/ +class IRSubstituteInScope : public StmtExprMutator { + public: + explicit IRSubstituteInScope(std::function fmap) + : fmap_(std::move(fmap)) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = fmap_(op); + if (it.defined()) { + return it; + } else { + return GetRef(op); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + auto fmutate = [&](const PrimExpr& e) { return this->VisitExpr(e); }; + Array v = op->iter_values; + v.MutateByApply(fmutate); + PrimExpr pred = this->VisitExpr(op->predicate); + if (v.same_as(op->iter_values) && pred.same_as(op->predicate)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_values = std::move(v); + n->predicate = std::move(analyzer.Simplify(pred)); + return Stmt(n); + } + } + + private: + const std::function fmap_; + arith::Analyzer analyzer; +}; + +Stmt SubstituteInScope(const Stmt& stmt, const Map& subst_map) { + auto fmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = subst_map.find(GetRef(v)); + if (it != subst_map.end()) { + return (*it).second; + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(std::move(fmap))(stmt); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 3932c4bdbd3d..b1ce52407baf 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -104,6 +104,14 @@ Array ReplaceBuffer(Array match_buffers, c void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt); +/******** IR Substitution ********/ + +/*! + * \param var_map The mapping of var + * \return The converted stmt + */ +Stmt SubstituteInScope(const Stmt& stmt, const Map& subst_map); + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index c1fa7efbf827..84fd329899ac 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -1,10 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring import sys import pytest import tvm -from tvm import tir, te from tvm.script import tir as T +from tvm import tir from tvm.tir.schedule.testing import verify_trace_roundtrip +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @T.prim_func def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): @@ -15,12 +34,14 @@ def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128 @T.prim_func -def single_elementwise_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None: +def single_elementwise_blockized1( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: with T.block("blockized_B"): vio = T.axis.spatial(1, 0) vjo = T.axis.spatial(1, 0) - T.reads(A[0 : 128, 0 : 128]) - T.writes(B[0 : 128, 0 : 128]) + T.reads(A[0:128, 0:128]) + T.writes(B[0:128, 0:128]) for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) @@ -30,44 +51,45 @@ def single_elementwise_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer @T.prim_func -def two_elementwise(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - C = T.match_buffer(c, (128, 128)) - B = T.alloc_buffer((128, 128)) - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] * 2.0 - for i, j in T.grid(128, 128): - with T.block("C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi, vj] + 1.0 +def single_elementwise_blockized2( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + for i in T.serial(128): + with T.block("blockized_B"): + vi = T.axis.spatial(128, i) + vjo = T.axis.spatial(1, 0) + T.reads(A[vi, 0:128]) + T.writes(B[vi, 0:128]) + for j in T.serial(128): + with T.block("B"): + vj = T.axis.spatial(128, j) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) @T.prim_func -def blockize(a: T.handle, c: T.handle) -> None: - C = T.match_buffer(c, (128, 128), "float32") - A = T.match_buffer(a, (128, 128), "float32") - B = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(8, 8): - with T.block("blockized_B"): +def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i, j in T.grid(128, 128): + with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - for ii, jj in T.grid(16, 16): - with T.block("B"): - vii = T.axis.S(128, vi * 16 + ii) - vjj = T.axis.S(128, vj * 16 + jj) - B[vii, vjj] = A[vii, vjj] * T.float32(2) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func -def blockize_schedule_1(a: T.handle, c: T.handle) -> None: - C = T.match_buffer(c, [128, 128]) - A = T.match_buffer(a, [128, 128]) - # body +def two_elementwise_blockized( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"] +) -> None: with T.block("root"): T.reads([]) T.writes([]) @@ -103,49 +125,7 @@ def blockize_schedule_1(a: T.handle, c: T.handle) -> None: @T.prim_func -def blockize_schedule_2(a: T.handle, c: T.handle) -> None: - C = T.match_buffer(c, [128, 128]) - A = T.match_buffer(a, [128, 128]) - # body - with T.block("root"): - T.reads([]) - T.writes([]) - B = T.alloc_buffer([128, 128]) - for i0_outer in range(0, 4): - for i1_outer in range(0, 4): - for ax0 in range(0, 2): - for ax1 in range(0, 2): - with T.block("blockized_B"): - vio = T.axis.S(8, ((i0_outer * 2) + ax0)) - vjo = T.axis.S(8, ((i1_outer * 2) + ax1)) - T.reads( - [A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] - ) - T.writes( - [B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] - ) - for i0_inner in range(0, 16): - for i1_inner in range(0, 16): - with T.block("B"): - vi = T.axis.S(128, ((vio * 16) + i0_inner)) - vj = T.axis.S(128, ((vjo * 16) + i1_inner)) - T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) - T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) - B[vi, vj] = A[vi, vj] * T.float32(2) - for i0_inner_1 in range(0, 32): - for i1_inner_1 in range(0, 32): - with T.block("C"): - vi = T.axis.S(128, ((i0_outer * 32) + i0_inner_1)) - vj = T.axis.S(128, ((i1_outer * 32) + i1_inner_1)) - T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) - T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) - C[vi, vj] = B[vi, vj] + T.float32(1) - - -@T.prim_func -def rowsum(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128]) +def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: for k, i in T.grid(128, 128): with T.block("B"): vk, vi = T.axis.remap("RS", [k, i]) @@ -155,9 +135,7 @@ def rowsum(a: T.handle, b: T.handle) -> None: @T.prim_func -def rowsum_blockized(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128]) +def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: with T.block("blockized_B"): vko = T.axis.R(1, 0) vio = T.axis.S(1, 0) @@ -172,24 +150,34 @@ def rowsum_blockized(a: T.handle, b: T.handle) -> None: B[vi] = B[vi] + A[vi, vk] -def test_blockize(): - func = elementwise +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +def test_blockize_outer(): + func = single_elementwise # schedule s = tir.Schedule(func, debug_mask="all") B = s.get_block("B") - _ = s.get_block("C") x, y = s.get_loops(B) - xo, xi = s.split(x, factors=[None, 16]) - yo, yi = s.split(y, factors=[None, 16]) - s.reorder(xo, yo, xi, yi) - s.blockize(xi) - tvm.ir.assert_structural_equal(blockize, s.mod["main"]) + s.blockize(x) + tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_inner(): + func = single_elementwise + # schedule + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + x, y = s.get_loops(B) + s.blockize(y) + print(s.mod["main"].script()) + tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized2) verify_trace_roundtrip(sch=s, mod=func) -def test_blockize_schedule(): - func = elementwise - # test 1 +def test_two_elementwise_blockize_reverse_compute_at(): + func = two_elementwise s = tir.Schedule(func, debug_mask="all") B = s.get_block("B") C = s.get_block("C") @@ -200,9 +188,12 @@ def test_blockize_schedule(): s.blockize(xi) s.reverse_compute_at(C, yo) s.blockize(s.get_loops(C)[-2]) - tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) verify_trace_roundtrip(sch=s, mod=func) - # test 2 + + +def test_two_elementwise_blockize_compute_at(): + func = two_elementwise s = tir.Schedule(func, debug_mask="all") B = s.get_block("B") C = s.get_block("C") @@ -213,33 +204,16 @@ def test_blockize_schedule(): s.blockize(xi) s.compute_at(B, yo) s.blockize(s.get_loops(B)[-2]) - tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) verify_trace_roundtrip(sch=s, mod=func) - # test 3 - s = tir.Schedule(func, debug_mask="all") - B = s.get_block("B") - C = s.get_block("C") - x, y = s.get_loops(B) - xo, xi = s.split(x, factors=[None, 16]) - yo, yi = s.split(y, factors=[None, 16]) - s.reorder(xo, yo, xi, yi) - b_outer = s.blockize(xi) - xC, yC = s.get_loops(C) - xCo, xCi = s.split(xC, factors=[None, 32]) - yCo, yCi = s.split(yC, factors=[None, 32]) - s.reorder(xCo, yCo, xCi, yCi) - s.compute_at(b_outer, yCo) - tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_2) - verify_trace_roundtrip(sch=s, mod=func) def test_blockize_init_loops(): s = tir.Schedule(rowsum, debug_mask="all") k, _ = s.get_loops(s.get_block("B")) s.blockize(k) - print(s.mod['main'].script()) tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) - # verify_trace_roundtrip(sch=s, mod=rowsum) + verify_trace_roundtrip(sch=s, mod=rowsum) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py new file mode 100644 index 000000000000..a9eb6860f9cd --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -0,0 +1,446 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +@T.prim_func +def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads( + [ + C[vi : vi + 16, vj : vj + 16], + A[vi : vi + 16, vk : vk + 16], + B[vj : vj + 16, vk : vk + 16], + ] + ) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, ()) + + with T.block("root"): + v0 = T.axis.R(4, 0) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.R(4, v0 + i) + C[()] = C[()] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), offset_factor=1) + B = T.match_buffer(b, (4,), offset_factor=1) + C = T.match_buffer(c, (), offset_factor=1) + + with T.block("root"): + v0 = T.axis.R(4, 0) + T.reads(C[()], A[v0 : v0 + 4], B[v0 : v0 + 4]) + T.writes(C[()]) + T.evaluate( + T.call_extern( + "vec4add", + C.data, + C.elem_offset, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 1)) + B = T.match_buffer(b, (16, 1)) + C = T.match_buffer(c, (16, 16)) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(1, 0) + T.reads( + C[vi : vi + 16, vj : vj + 16], + A[vi : vi + 16, vk : vk + 1], + B[vj : vj + 16, vk : vk + 1], + ) + T.writes(C[vi : vi + 16, vj : vj + 16]) + for i, j in T.grid(16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = C[vii, vjj] + A[vii, vk] * B[vjj, vk] + + +@T.prim_func +def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 1)) + B = T.match_buffer(b, (16, 1)) + C = T.match_buffer(c, (16, 16)) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(1, 0) + T.reads( + C[vi : vi + 16, vj : vj + 16], + A[vi : vi + 16, vk : vk + 1], + B[vj : vj + 16, vk : vk + 1], + ) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate( + T.call_extern( + "outer_product", + C.data, + C.elem_offset, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def matmul( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], +) -> None: + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + + for i_outer, j_outer in T.grid(8, 8): + for i_inner_init, j_inner_init in T.grid(16, 16): + with T.block("init"): + vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init)) + vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init)) + C[vi_init, vj_init] = T.float32(0) + for k_outer in T.grid(8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) + T.reads( + [ + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + ] + ) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A_elem_offset = T.var("int32") + B_elem_offset = T.var("int32") + C_elem_offset = T.var("int32") + A_sub = T.match_buffer( + A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + [16, 16], + elem_offset=A_elem_offset, + ) + B_sub = T.match_buffer( + B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + [16, 16], + elem_offset=B_elem_offset, + ) + C_sub = T.match_buffer( + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + [16, 16], + elem_offset=C_elem_offset, + ) + T.evaluate( + T.tvm_mma_sync( + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + A_sub.data, + T.floordiv(A_sub.elem_offset, 256), + B_sub.data, + T.floordiv(B_sub.elem_offset, 256), + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + dtype="handle", + ) + ) + + +@T.prim_func +def batch_matmul( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@T.prim_func +def tensorized_batch_matmul_mma( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n in range(0, 16): + for i, j, k in T.grid(8, 8, 8): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + T.reads( + C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + ) + T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A_elem_offset = T.var("int32") + B_elem_offset = T.var("int32") + C_elem_offset = T.var("int32") + A_sub = T.match_buffer( + A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + elem_offset=A_elem_offset, + ) + B_sub = T.match_buffer( + B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + elem_offset=B_elem_offset, + ) + C_sub = T.match_buffer( + C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + elem_offset=C_elem_offset, + ) + T.evaluate( + T.tvm_mma_sync( + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + A_sub.data, + T.floordiv(A_sub.elem_offset, 256), + B_sub.data, + T.floordiv(B_sub.elem_offset, 256), + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + dtype="handle", + ) + ) + + +@T.prim_func +def tensorized_batch_matmul_dot_product( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n, i, j, k_0 in T.grid(16, 128, 128, 32): + with T.block("blockized_update"): + vn, vi, vj, vko = T.axis.remap("SSSR", [n, i, j, k_0]) + T.reads( + C[vn, vi, vj], A[vn, vi, vko * 4 : vko * 4 + 4], B[vn, vj, vko * 4 : vko * 4 + 4] + ) + T.writes(C[vn, vi, vj]) + A_1 = T.match_buffer( + A[vn, vi, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1 + ) + B_1 = T.match_buffer( + B[vn, vj, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1 + ) + C_1 = T.match_buffer(C[vn, vi, vj], [], dtype="float32", offset_factor=1) + T.evaluate( + T.call_extern( + "vec4add", + C_1.data, + C_1.elem_offset, + A_1.data, + A_1.elem_offset, + B_1.data, + B_1.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def tensorized_batch_matmul_outer_product( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n, i_0, j_0, k in T.grid(16, 8, 8, 128): + with T.block("blockized_update"): + vn, vio, vjo, vk = T.axis.remap("SSSR", [n, i_0, j_0, k]) + T.reads( + C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + A[vn, vio * 16 : vio * 16 + 16, vk], + B[vn, vjo * 16 : vjo * 16 + 16, vk], + ) + T.writes(C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + A_1 = T.match_buffer(A[vn, vio * 16 : vio * 16 + 16, 0], [16, 1], dtype="float32") + B_1 = T.match_buffer(B[vn, vjo * 16 : vjo * 16 + 16, 0], [16, 1], dtype="float32") + C_1 = T.match_buffer( + C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], [16, 16], dtype="float32" + ) + T.evaluate( + T.call_extern("outer_product", C_1.data, 0, A_1.data, 0, B_1.data, 0, dtype="int32") + ) + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) +tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_intrin) +tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, outer_product_intrin) + + +def test_tensorize_matmul(): + func = matmul + # schedule + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.decompose_reduction(update, ko) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_matmul, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_tensorize_batch_matmul(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + _, i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_mma, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=batch_matmul) + + +def test_tensorize_dot_product(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + C = s.get_block("update") + _, _, _, k = s.get_loops(C) + _, ki = s.split(k, factors=[None, 4]) + s.tensorize(ki, "test_dot_product_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_dot_product, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_tensorize_outer_product(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + C = s.get_block("update") + _, i, j, k = s.get_loops(C) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + s.reorder(io, jo, k, ii, ji) + s.tensorize(ii, "test_outer_product_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_outer_product, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From a22afce33926dd168ea3348a9969338e02f36d7a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 7 Jan 2022 17:37:28 -0500 Subject: [PATCH 05/20] add examples --- include/tvm/tir/schedule/schedule.h | 4 +- python/tvm/tir/schedule/schedule.py | 169 ++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 58d57222b439..502087e37571 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -481,13 +481,13 @@ class ScheduleNode : public runtime::Object { virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; /*! * \brief Tensorize the computation enclosed by loop with tensor_intrin - * \param loop_rv the loop/block to be tensorized + * \param loop_rv the loop to be tensorized * \param intrin the tensor intrinsic */ virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 0; /*! * \brief Tensorize the computation enclosed by loop with tensor_intrin - * \param loop_rv The loop/block to be tensorized + * \param loop_rv The loop to be tensorized * \param intrin_name Name of the tensor intrinsic */ virtual void Tensorize(const LoopRV& loop_rv, const String& intrin_name) = 0; diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index f0d33f2e7c26..3c1bd67d2c99 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1836,6 +1836,175 @@ def after_blockize( @type_checked def tensorize(self, loop: LoopRV, tensor_intrin: Union[str, TensorIntrin]) -> None: + """Tensorize the computation enclosed by loop with the tensor intrinsic. + + Parameters + ---------- + loop : LoopRV + The loop to be tensorized. + tensor_intrin : Union[str, TensorIntrin] + The tensor intrin or the name of the tensor intrin. + + Examples + -------- + + Before tensorize, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_tensorize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + # body + # with T.block("root") + for i_0, j_0 in T.grid(8, 8): + for i_1_init, j_1_init in T.grid(16, 16): + with T.block("init"): + vi = T.axis.spatial(128, i_0 * 16 + i_1_init) + vj = T.axis.spatial(128, j_0 * 16 + j_1_init) + T.reads() + T.writes(C[vi, vj]) + C[vi, vj] = T.float32(0) + for k_0, i_1, j_1, k_1 in T.grid(8, 16, 16, 16): + with T.block("update"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + vk = T.axis.reduce(128, k_0 * 16 + k_1) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(C[vi, vj]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + Declare and register the tensor intrinsic: + + .. code-block:: python + + @T.prim_func + def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + + @T.prim_func + def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads( + [ + C[vi : vi + 16, vj : vj + 16], + A[vi : vi + 16, vk : vk + 16], + B[vj : vj + 16, vk : vk + 16], + ] + ) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) + + Create the schedule and do tensorize: + + .. code-block:: python + + sch = tir.Schedule(before_tensorize) + update = s.get_block("update") + _, _, _, i1, _, _ = s.get_loops(update) + s.tensorize(ii, "test_mma_intrin") + print(sch.mod["main"].script()) + + After applying tensorize, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_tensoirze( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + # body + # with T.block("root") + for i_0, j_0 in T.grid(8, 8): + for i_1_init, j_1_init in T.grid(16, 16): + with T.block("init"): + vi = T.axis.spatial(128, i_0 * 16 + i_1_init) + vj = T.axis.spatial(128, j_0 * 16 + j_1_init) + T.reads() + T.writes(C[vi, vj]) + C[vi, vj] = T.float32(0) + for k_0 in T.serial(8): + with T.block("blockized_update"): + vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0]) + T.reads( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + ) + T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + A_1 = T.match_buffer( + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + B_1 = T.match_buffer( + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + C_1 = T.match_buffer( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + T.evaluate( + T.tvm_mma_sync( + C_1.data, + C_1.elem_offset // 256, + A_1.data, + A_1.elem_offset // 256, + B_1.data, + B_1.elem_offset // 256, + C_1.data, + C_1.elem_offset // 256, + dtype="handle", + ) + ) + + """ if isinstance(tensor_intrin, str): tensor_intrin = String(tensor_intrin) _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member From 207bfb3896cf095875361db9349d6a0c9384f830 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 7 Jan 2022 17:41:58 -0500 Subject: [PATCH 06/20] lint --- python/tvm/tir/function.py | 4 +--- python/tvm/tir/schedule/schedule.py | 19 ++++++++++--------- src/tir/ir/function.cc | 7 ++++--- src/tir/schedule/analysis.h | 9 +++++++-- src/tir/schedule/analysis/analysis.cc | 2 +- .../schedule/primitive/blockize_tensorize.cc | 2 +- src/tir/schedule/traced_schedule.cc | 3 ++- 7 files changed, 26 insertions(+), 20 deletions(-) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index ed693a9f3c49..a372ae09ce97 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -182,9 +182,7 @@ def __init__(self, desc_func, intrin_func): @staticmethod def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc): - return _ffi_api.TensorIntrinRegister( # type: ignore - name, desc_func, intrin_func - ) + return _ffi_api.TensorIntrinRegister(name, desc_func, intrin_func) # type: ignore @staticmethod def get(name: str): diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 3c1bd67d2c99..8adf5c692462 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1792,7 +1792,7 @@ def before_blockize( T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) - + Create the schedule and do set_scope: .. code-block:: python @@ -1829,9 +1829,9 @@ def after_blockize( ---- blockize requires there is exactly one block under the given loop and the bindings of the block are divisible by the subspace represented by the loops starting at the given loop. - + """ - + return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member @type_checked @@ -1886,7 +1886,7 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) - + with T.block("root"): vi = T.axis.S(16, 0) vj = T.axis.S(16, 0) @@ -1897,14 +1897,14 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: vjj = T.axis.S(16, vj + j) vkk = T.axis.R(16, vk + k) C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] - - + + @T.prim_func def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) - + with T.block("root"): vi = T.axis.S(16, 0) vj = T.axis.S(16, 0) @@ -2003,12 +2003,13 @@ def after_tensoirze( dtype="handle", ) ) - + """ if isinstance(tensor_intrin, str): tensor_intrin = String(tensor_intrin) _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member - self, loop, tensor_intrin) + self, loop, tensor_intrin + ) ########## Schedule: Annotation ########## diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 8d99c864fa49..83e201eea08b 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -80,8 +80,10 @@ TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) { CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size()); // check both functions' bodies are directly block - const auto* desc_realize = Downcast(desc_func->body)->block->body.as(); - const auto* intrin_realize = Downcast(intrin_func->body)->block->body.as(); + const auto* desc_realize = + Downcast(desc_func->body)->block->body.as(); + const auto* intrin_realize = + Downcast(intrin_func->body)->block->body.as(); CHECK(desc_realize != nullptr) << "description function's body expect a directly block"; CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block"; @@ -144,7 +146,6 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc") return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); - TVM_REGISTER_GLOBAL("tir.TensorIntrin") .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { return TensorIntrin(desc_func, intrin_func); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 2971493363db..59e6a2fef2db 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -563,9 +563,13 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // using ExprComparator = ExprFunctor; using StmtComparator = StmtFunctor; -/* \brief Deep comparison to check if two IR ASTs are equivalent */ +/*! \brief Deep comparison to check if two IR ASTs are equivalent */ class TensorizeComparator : public ExprComparator, public StmtComparator { public: + /*! + * \brief Constructor of TensorizeComparator + * \param assert_mode Whether to raise an error if the two IR ASTs do not match. + */ explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {} // Map from rhs buffer to lhs buffer @@ -623,7 +627,8 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { protected: bool assert_mode_; - bool is_scope_block = true, is_inner_block = true; + bool is_scope_block = true; + bool is_inner_block = true; }; } // namespace tir diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 918caa5091a7..abbeb9ad7576 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1659,7 +1659,7 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } catch (...) { throw InvalidStorageScopeError(self->mod, std::move(storage_scope)); } -}; +} /******** Tensorize Comparator ********/ diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 29248ce171dc..1d536c0f1bb5 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -364,7 +364,7 @@ BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, new_region.push_back(relaxed_int_set[i].CoverRange(max_range)); } return BufferRegion(buffer_region->buffer, std::move(new_region)); -}; +} /*! * \brief Generate the outer block after blockize. diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index f699963f7e11..b888833e2385 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -378,7 +378,8 @@ void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_n } void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& tensor_intrin) { - LOG(FATAL) << "TensorIntrin cannot be directly passed to meta schedule. Please register the tensor intrin and pass the intrin name instead."; + LOG(FATAL) << "TensorIntrin cannot be directly passed to meta schedule. Please register the " + "tensor intrin and pass the intrin name instead."; } /******** Schedule: Annotation ********/ From ae3d62777f4868d555204cb62486b4ce13cc74b0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 10 Jan 2022 14:41:38 -0500 Subject: [PATCH 07/20] Amend co-authors information Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou From 44c28171cf530d3a3ca06f6a7d1e9520374c2345 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 14 Jan 2022 13:15:07 -0500 Subject: [PATCH 08/20] WIP --- include/tvm/tir/function.h | 24 +- include/tvm/tir/schedule/schedule.h | 16 +- python/tvm/tir/function.py | 39 +- python/tvm/tir/schedule/schedule.py | 10 +- src/tir/ir/function.cc | 49 +-- src/tir/schedule/analysis.h | 73 ---- src/tir/schedule/analysis/analysis.cc | 330 +-------------- src/tir/schedule/concrete_schedule.cc | 8 +- src/tir/schedule/concrete_schedule.h | 4 +- src/tir/schedule/ir_comparator.cc | 380 ++++++++++++++++++ src/tir/schedule/ir_comparator.h | 107 +++++ src/tir/schedule/primitive.h | 6 +- .../schedule/primitive/blockize_tensorize.cc | 77 ++-- src/tir/schedule/schedule.cc | 16 +- src/tir/schedule/state.cc | 3 +- src/tir/schedule/traced_schedule.cc | 17 +- src/tir/schedule/traced_schedule.h | 4 +- 17 files changed, 646 insertions(+), 517 deletions(-) create mode 100644 src/tir/schedule/ir_comparator.cc create mode 100644 src/tir/schedule/ir_comparator.h diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 3f59066ffa6b..a85d19a1be76 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -193,13 +193,13 @@ class LinkedParam : public ObjectRef { class TensorIntrinNode : public Object { public: /*! \brief The function to describe the computation. */ - PrimFunc description; - /*! \brief The intrinsic function for lower-level implementation. */ - PrimFunc implementation; + PrimFunc desc; + /*! \brief The function of the implementation for the execution. */ + PrimFunc impl; void VisitAttrs(AttrVisitor* v) { - v->Visit("description", &description); - v->Visit("implementation", &implementation); + v->Visit("desc", &desc); + v->Visit("impl", &impl); } static constexpr const char* _type_key = "tir.TensorIntrin"; @@ -213,8 +213,8 @@ class TensorIntrin : public ObjectRef { public: /*! * \brief Constructor - * \param desc_func The function to describe the computation. - * \param intrin_func The intrinsic function for lower-level implementation. + * \param desc The function to describe the computation. + * \param impl The function of the implementation for the execution. */ TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func); @@ -222,15 +222,17 @@ class TensorIntrin : public ObjectRef { * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked * up with its name. * \param name The name of the TensorIntrin to register - * \param desc_func The function to describe the computation. - * \param intrin_func The intrinsic function for lower-level implementation. - * \return The created TensorIntrin. + * \param intrin The TensorIntrin to register. + * \throws This method throws an exception if the TensorIntrin with the specified name already + * exists. */ - TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, PrimFunc intrin_func); + TVM_DLL static void Register(String name, TensorIntrin intrin); /*! * \brief Look up TensorIntrin by name. Raises an exception if not found. * \param name The name of the TensorIntrin. + * \return The TensorIntrin with the specified name. + * \throws This method throws an exception if the TensorIntrin does not exist. */ TVM_DLL static TensorIntrin Get(String name); diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 502087e37571..be06b44820cd 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -480,17 +480,17 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; /*! - * \brief Tensorize the computation enclosed by loop with tensor_intrin - * \param loop_rv the loop to be tensorized - * \param intrin the tensor intrinsic + * \brief Tensorize the computation enclosed by loop with the tensor intrin. + * \param loop_rv The loop to be tensorized + * \param intrin Name of the tensor intrinsic */ - virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 0; + virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0; /*! - * \brief Tensorize the computation enclosed by loop with tensor_intrin - * \param loop_rv The loop to be tensorized - * \param intrin_name Name of the tensor intrinsic + * \brief Tensorize the computation enclosed by loop with the tensor intrin. + * \param block_rv The block to be tensorized + * \param intrin Name of the tensor intrinsic */ - virtual void Tensorize(const LoopRV& loop_rv, const String& intrin_name) = 0; + virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0; /******** Schedule: Annotation ********/ /*! diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index a372ae09ce97..bcebab9ddc0a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -170,20 +170,43 @@ class TensorIntrin(Object): Parameters ---------- - desc_func: PrimFunc - The function to describe the computation + desc : PrimFunc + The function to describe the computation. - intrin_func: PrimFunc - The function for execution + impl : PrimFunc + The function of the implementation for the execution. """ - def __init__(self, desc_func, intrin_func): - self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc_func, intrin_func) + def __init__(self, desc, impl): + self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl) @staticmethod - def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc): - return _ffi_api.TensorIntrinRegister(name, desc_func, intrin_func) # type: ignore + def register(name: str, desc: PrimFunc, impl: PrimFunc): + """Register a tensor intrinsic with its name. + + Parameters + ---------- + name : str + The name of the TensorIntrin to register. + desc : PrimFunc + The function to describe the computation. + impl : PrimFunc + The function of the implementation for the execution. + """ + return _ffi_api.TensorIntrinRegister(name, TensorIntrin(desc, impl)) # type: ignore @staticmethod def get(name: str): + """Look up a tensor intrinsic by its name. + + Parameters + ---------- + name : str + The name of the TensorIntrin to look up. + + Returns + ------- + result : TensorIntrin + The TensorIntrin with the specified name. + """ return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 8adf5c692462..c7caf6737549 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1835,14 +1835,14 @@ def after_blockize( return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member @type_checked - def tensorize(self, loop: LoopRV, tensor_intrin: Union[str, TensorIntrin]) -> None: + def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None: """Tensorize the computation enclosed by loop with the tensor intrinsic. Parameters ---------- - loop : LoopRV + block_or_loop : Union[BlockRV, LoopRV] The loop to be tensorized. - tensor_intrin : Union[str, TensorIntrin] + tensor_intrin : str The tensor intrin or the name of the tensor intrin. Examples @@ -2005,10 +2005,8 @@ def after_tensoirze( ) """ - if isinstance(tensor_intrin, str): - tensor_intrin = String(tensor_intrin) _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member - self, loop, tensor_intrin + self, block_or_loop, tensor_intrin ) ########## Schedule: Annotation ########## diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 83e201eea08b..3b2b4057aff4 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -74,51 +74,54 @@ class TensorIntrinManager { } }; -TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) { - // check the number of func var is equal - CHECK_EQ(desc_func->params.size(), intrin_func->params.size()); - CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size()); +TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { + // Check the number of func var is equal + CHECK_EQ(desc->params.size(), impl->params.size()) << "ValueError: The number of parameters of the description and the implementation of the tensor intrinsic doesn't match."; + for (size_t i = 0; i < desc->params.size(); i++) { + CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the tensor intrinsic should be handle only."; + CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of the tensor intrinsic should be handle only."; + } + ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); // check both functions' bodies are directly block const auto* desc_realize = - Downcast(desc_func->body)->block->body.as(); - const auto* intrin_realize = - Downcast(intrin_func->body)->block->body.as(); + Downcast(desc->body)->block->body.as(); + const auto* impl_realize = + Downcast(impl->body)->block->body.as(); CHECK(desc_realize != nullptr) << "description function's body expect a directly block"; - CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block"; + CHECK(impl_realize != nullptr) << "intrinsic function's body expect a directly block"; const Block& desc_block = desc_realize->block; - const Block& intrin_block = intrin_realize->block; + const Block& impl_block = impl_realize->block; // check block var number and iter type - CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size()) - << "Two blocks should have the same number of block vars"; + CHECK_EQ(desc_block->iter_vars.size(), impl_block->iter_vars.size()) + << "ValueError: The blocks in the description and the implementation should have the same number of block vars"; for (size_t i = 0; i < desc_block->iter_vars.size(); i++) { const IterVar& desc_var = desc_block->iter_vars[i]; - const IterVar& intrin_var = intrin_block->iter_vars[i]; - CHECK(desc_var->iter_type == intrin_var->iter_type) - << "Block iter_type mismatch between " << desc_var->iter_type << " and " - << intrin_var->iter_type; + const IterVar& impl_var = impl_block->iter_vars[i]; + CHECK(desc_var->iter_type == impl_var->iter_type) + << "Block iter_type mismatch between " << IterVarType2String(desc_var->iter_type) << " and " + << IterVarType2String(impl_var->iter_type); } - auto n = make_object(); - n->description = std::move(desc_func); - n->implementation = std::move(intrin_func); + ObjectPtr n = make_object(); + n->desc = std::move(desc); + n->impl = std::move(impl); data_ = std::move(n); } -TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc intrin_func) { +void TensorIntrin::Register(String name, TensorIntrin intrin) { TensorIntrinManager* manager = TensorIntrinManager::Global(); - ICHECK_EQ(manager->reg.count(name), 0) + CHECK_EQ(manager->reg.count(name), 0) << "ValueError: TensorIntrin '" << name << "' has already been registered"; - TensorIntrin intrin(desc_func, intrin_func); manager->reg.Set(name, intrin); - return intrin; } TensorIntrin TensorIntrin::Get(String name) { const TensorIntrinManager* manager = TensorIntrinManager::Global(); - ICHECK_EQ(manager->reg.count(name), 1) + auto it = manager->reg.find(name); + CHECK(it != manager->reg.end()) << "ValueError: TensorIntrin '" << name << "' is not registered"; return manager->reg.at(name); } diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 59e6a2fef2db..636cc7d0a5db 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -558,79 +558,6 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // int64_t max_parallel_extent, // int64_t max_parallel_basic); -/******** Tensorization Related ********/ - -using ExprComparator = ExprFunctor; -using StmtComparator = StmtFunctor; - -/*! \brief Deep comparison to check if two IR ASTs are equivalent */ -class TensorizeComparator : public ExprComparator, public StmtComparator { - public: - /*! - * \brief Constructor of TensorizeComparator - * \param assert_mode Whether to raise an error if the two IR ASTs do not match. - */ - explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {} - - // Map from rhs buffer to lhs buffer - std::unordered_map rhs_buffer_map_; - // Buffer indices mapping - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; - std::vector extra_block_vars_; - // variable remap if any - std::unordered_map equal_map_; - - bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; - bool VisitStmt(const Stmt& n, const Stmt& other) override; - - bool VisitStmt_(const ForNode* op, const Stmt& other) override; - bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; - bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; - bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; - bool VisitStmt_(const BlockNode* op, const Stmt& other) override; - - bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; - bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; - bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; - bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; - bool VisitExpr_(const ModNode* op, const PrimExpr& other) override; - bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; - bool VisitExpr_(const NENode* op, const PrimExpr& other) override; - bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; - bool VisitExpr_(const LENode* op, const PrimExpr& other) override; - bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; - bool VisitExpr_(const GENode* op, const PrimExpr& other) override; - bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; - bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; - bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; - bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; - bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; - bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; - bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; - bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; - bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; - bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; - bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; - - bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); - virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); - bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); - bool CompareAnnotation(const std::pair& lhs, - const std::pair& rhs); - bool CompareAnnotationMap(const Map& lhs, const Map& rhs); - template - bool CompareBufferAccess(const T* lhs, const T* rhs); - template - bool CompareArray(const Array& lhs, const Array& rhs, F cmp); - bool CompareRange(const Range& lhs, const Range& rhs); - bool CompareType(const DataType& lhs, const DataType& rhs); - - protected: - bool assert_mode_; - bool is_scope_block = true; - bool is_inner_block = true; -}; - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index abbeb9ad7576..ddb4d3dd0a1c 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1661,334 +1661,6 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } } -/******** Tensorize Comparator ********/ - -bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { - if (n.same_as(other)) return true; - if (n->type_index() != other->type_index()) return false; - bool equal = StmtComparator::VisitStmt(n, other); - if (!equal && assert_mode_) - LOG(FATAL) << "Stmts are not matching between:\n" << n << "\nand\n" << other; - return equal; -} - -bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { - const auto* rhs = other.as(); - if (!DefEqual(op->loop_var, rhs->loop_var)) return false; - if (!VisitExpr(op->min, rhs->min)) return false; - if (!VisitExpr(op->extent, rhs->extent)) return false; - if (!VisitStmt(op->body, rhs->body)) return false; - if (op->kind != rhs->kind) return false; - if (op->thread_binding.defined() ^ rhs->thread_binding.defined()) return false; - if (op->thread_binding.defined() && - !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) - return false; - return CompareAnnotationMap(op->annotations, rhs->annotations); -} - -bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { - const auto* rhs = other.as(); - return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); -} - -bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { - const auto* rhs = other.as(); - return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); -} - -bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { - const auto* rhs = other.as(); - // Skip Compare binding values if the block is scope block (the outermost one). - if (!is_scope_block) { - size_t offset = op->iter_values.size() - rhs->iter_values.size(); - if (rhs->iter_values.size() > op->iter_values.size()) return false; - if (is_inner_block) { - // weak pattern matching for the inner block (the son of the scope block) - // where the pattern is v + iter <=> expr + iter - for (size_t i = 0; i < rhs->iter_values.size(); ++i) { - PrimExpr lhs_expr, rhs_expr; - Optional lhs_iter, rhs_iter; - auto detect = [](const PrimExpr& binding) -> std::pair> { - arith::PVar expr; - arith::PVar iter; - if (iter.Match(binding)) { - return std::make_pair(0, iter.Eval()); - } else if ((expr + iter).Match(binding)) { - return std::make_pair(expr.Eval(), iter.Eval()); - } else if ((iter + expr).Match(binding)) { - return std::make_pair(expr.Eval(), iter.Eval()); - } else { - return std::make_pair(expr.Eval(), NullOpt); - } - }; - std::tie(lhs_expr, lhs_iter) = detect(op->iter_values[i + offset]); - std::tie(rhs_expr, rhs_iter) = detect(rhs->iter_values[i]); - CHECK((lhs_iter && rhs_iter) || (!lhs_iter && !rhs_iter)) << "Incompatible binding"; - if (lhs_iter) VisitExpr(lhs_iter.value(), rhs_iter.value()); - if (is_zero(rhs_expr)) { - CHECK(is_zero(lhs_expr)) << "Incompatible binding"; - } else { - const auto* bv = rhs_expr.as(); - if (!bv) { - VisitExpr(lhs_expr, rhs_expr); - } else { - auto it = equal_map_.find(GetRef(bv)); - if (it == equal_map_.end()) { - equal_map_[GetRef(bv)] = lhs_expr; - } else { - CHECK(it->second->IsInstance()); - VisitExpr(lhs_expr, Downcast(it->second)); - } - } - } - } - } else { - for (size_t i = 0; i < rhs->iter_values.size(); ++i) { - if (!VisitExpr(op->iter_values[i + offset], rhs->iter_values[i])) return false; - } - const Block& block = op->block; - for (size_t i = 0; i < offset; ++i) { - Var block_var = Downcast(op->iter_values[i]); - auto it = equal_map_.find(block_var); - equal_map_[block->iter_vars[i]->var] = (it == equal_map_.end() ? block_var : it->second); - } - } - } - - return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); -} - -bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { - const auto* rhs = other.as(); - // Check block equality. - // All iter vars and buffer regions including the order shoudl match. - // When checking iter vars, DefEqual is used to remap variables. - // Only the inner most several axis are compared. Other iter vars are added to extra_block_vars. - if (op->iter_vars.size() < rhs->iter_vars.size()) return false; - - size_t offset = op->iter_vars.size() - rhs->iter_vars.size(); - for (size_t i = 0; i < rhs->iter_vars.size(); ++i) { - auto lhs_var = op->iter_vars[i + offset], rhs_var = rhs->iter_vars[i]; - // Skip iter dom - if (!DefEqual(lhs_var->var, rhs_var->var)) { - return false; - } - if (lhs_var->iter_type != rhs_var->iter_type) { - return false; - } - } - - if (is_scope_block) { - for (size_t i = 0; i < offset; ++i) { - extra_block_vars_.push_back(op->iter_vars[i]); - } - } - - if (!is_scope_block) { - if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { - return false; - } - if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { - return false; - } - if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { - return false; - } - if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { - return false; - } - } - is_scope_block = false; - return VisitStmt(op->body, rhs->body); -} - -// Exprs -#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ - bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ - const auto* rhs = other.as(); \ - return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ - } - -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); -TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); - -bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; -} - -bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; -} - -bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - return CompareType(op->dtype, rhs->dtype) && VisitExpr(op->value, rhs->value); -} - -bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - auto lhs = GetRef(op); - if (lhs.same_as(other)) return true; - if (!CompareType(op->dtype, rhs->dtype)) return false; - auto it = equal_map_.find(lhs); - return it != equal_map_.end() && it->second.same_as(other); -} - -bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { - const auto* rhs = other.as(); - return CompareBufferAccess(op, rhs); -} - -bool TensorizeComparator::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { - if (lhs.same_as(rhs)) return true; - if (lhs->type_index() != rhs->type_index()) return false; - auto it = equal_map_.find(lhs); - // If there is already a mapping - if (it != equal_map_.end()) return it->second.same_as(rhs); - equal_map_[lhs] = rhs; - return true; -} - -bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, - const std::pair& rhs) { - if (lhs.first != rhs.first) return false; - if (!lhs.second.same_as(rhs.second)) return false; - return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); -} - -bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, - const Map& rhs) { - if (lhs.same_as(rhs)) return true; - if (lhs.size() != rhs.size()) return false; - - auto sort_map = - [](const Map& map) -> std::vector> { - std::vector> ret; - ret.reserve(map.size()); - for (const auto& pair : map) { - ret.emplace_back(pair); - } - sort(ret.begin(), ret.end()); - return ret; - }; - - auto lhs_array = sort_map(lhs), rhs_array = sort_map(rhs); - - for (size_t i = 0; i < lhs.size(); ++i) { - if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; - } - return true; -} - -bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { - if (lhs.same_as(rhs)) return true; - // Remap both buffer itself and buffer data - // Skip buffer shape - bool equal = DefEqual(lhs, rhs) && DefEqual(lhs->data, rhs->data) && - CompareType(lhs->dtype, rhs->dtype) && lhs.scope() == rhs.scope(); - if (equal) { - rhs_buffer_map_[rhs] = lhs; - } else if (assert_mode_) { - LOG(FATAL) << "Buffers are not matching between:" << lhs << " and " << rhs; - } - return equal; -} - -bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { - // Only for block region declaration - if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; - // Number of indices in desc_block must be smaller than it in AST - if (rhs->region.size() > lhs->region.size()) return false; - - std::vector lhs_region; - for (const auto& range : lhs->region) { - lhs_region.push_back(Range::FromMinExtent(range->min, range->extent)); - } - size_t offset = lhs_region.size() - rhs->region.size(); - // initialize buffer indices - bool need_update = false; - if (!buffer_indices_.count(lhs->buffer)) { - need_update = true; - buffer_indices_[lhs->buffer] = std::vector(); - } else { - if (offset != buffer_indices_[lhs->buffer].size()) return false; - } - std::vector& indices = buffer_indices_[lhs->buffer]; - for (size_t i = 0; i < offset; ++i) { - const Range& range = lhs_region[i]; - // High-dim region must be element-wise - if (!is_one(range->extent)) return false; - if (need_update) { - indices.push_back(range->min); - } else { - // The order matters since we only map inner block_var to outside block_var - if (!VisitExpr(range->min, indices[i])) return false; - } - } - for (size_t i = 0; i < rhs->region.size(); ++i) { - if (!CompareRange(lhs_region[i + offset], rhs->region[i])) return false; - } - return true; -} - -// Only for BufferStoreNode and BufferLoadNode -template -bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { - if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; - - if (rhs->indices.size() > lhs->indices.size()) return false; - // otherwise - size_t offset = lhs->indices.size() - rhs->indices.size(); - for (size_t i = 0; i < rhs->indices.size(); ++i) { - if (!VisitExpr(lhs->indices[i + offset], rhs->indices[i])) return false; - } - return true; -} - -template -bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { - if (lhs.same_as(rhs)) return true; - if (lhs.size() != rhs.size()) return false; - for (size_t i = 0; i < lhs.size(); ++i) { - if (!(this->*cmp)(lhs[i], rhs[i])) return false; - } - return true; -} - -bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { - return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); -} - -bool TensorizeComparator::CompareType(const DataType& lhs, const DataType& rhs) { - if (lhs == rhs) return true; - return lhs.code() == rhs.code() && lhs.bits() == rhs.bits() && lhs.lanes() == rhs.lanes(); -} - -// Deep comparison to check if two IR graph are equivalent -bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { - bool equal = (n->type_index() == other->type_index()) && ExprComparator::VisitExpr(n, other); - if (!equal && assert_mode_) - LOG(FATAL) << "Exprs are not matching between:" << n << " and " << other; - return equal; -} - bool IsSpatial(const StmtSRef& block_sref) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); for (const IterVar& iter_var : block->iter_vars) { @@ -2176,5 +1848,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } } +======= +>>>>>>> cde8c476f (WIP) } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index a8f815fec172..fc63f305ff5e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -615,16 +615,16 @@ BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) { return CreateRV(result); } -void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) { +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Tensorize(state_, this->GetSRef(loop_rv), intrin); + tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin)); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } -void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { +void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name)); + tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin)); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index dde811db4d1a..5f108178a83b 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -124,8 +124,8 @@ class ConcreteScheduleNode : public ScheduleNode { void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv) override; - void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) override; - void Tensorize(const LoopRV& loop_rv, const String& intrin_name) override; + void Tensorize(const BlockRV& loop_rv, const String& intrin) override; + void Tensorize(const LoopRV& loop_rv, const String& intrin) override; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc new file mode 100644 index 000000000000..35c9e7098a31 --- /dev/null +++ b/src/tir/schedule/ir_comparator.cc @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./ir_comparator.h" + +namespace tvm { + +namespace tir { + +/******** Tensorize Comparator ********/ + +class TensorIntrinMismatchError : public ScheduleError { + public: + explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt) : + lhs_mod_(std::move(lhs_mod)), lhs_stmt_(std::move(lhs_stmt)), rhs_stmt_(std::move(rhs_stmt)) { + ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); + } + + String FastErrorString() const final { + return "ScheduleError: The stmt doesn't match the tensor intrin."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The stmt {0} doesn't match the tensor intrin\n " << rhs_stmt_; + return os.str(); + } + + IRModule mod() const final { return lhs_mod_; } + + Array LocationsOfInterest() const final { return {lhs_stmt_}; } + + private: + IRModule lhs_mod_; + Stmt lhs_stmt_; + Stmt rhs_stmt_; +}; + +/* Override the dispatcher to make sure RHS is always valid */ +bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { + bool equal = (n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other); + if (!equal && assert_mode_ && (n->IsInstance() || n->IsInstance())) { + throw TensorIntrinMismatchError(lhs_mod_, n, other); + } + return equal; +} + +bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { + return (n->type_index() == other->type_index()) && ExprComparator::VisitExpr(n, other); +} + +bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + if (!VisitExpr(op->min, rhs->min)) return false; + if (!VisitExpr(op->extent, rhs->extent)) return false; + if (op->thread_binding.defined() != rhs->thread_binding.defined()) return false; + if (op->thread_binding.defined() && + !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) { + return false; + } + if (op->kind != rhs->kind) return false; + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) return false; + return VisitStmt(op->body, rhs->body); +} + +bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); +} + +bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Skip Compare binding values if the block is scope block (the outermost one). + if (!is_scope_block) { + size_t offset = op->iter_values.size() - rhs->iter_values.size(); + if (rhs->iter_values.size() > op->iter_values.size()) return false; + if (is_inner_block) { + // weak pattern matching for the inner block (the son of the scope block) + // where the pattern is v + iter <=> expr + iter + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + PrimExpr lhs_expr, rhs_expr; + Optional lhs_iter, rhs_iter; + auto detect = [](const PrimExpr& binding) -> std::pair> { + arith::PVar expr; + arith::PVar iter; + if (iter.Match(binding)) { + return std::make_pair(0, iter.Eval()); + } else if ((expr + iter).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else if ((iter + expr).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else { + return std::make_pair(expr.Eval(), NullOpt); + } + }; + std::tie(lhs_expr, lhs_iter) = detect(op->iter_values[i + offset]); + std::tie(rhs_expr, rhs_iter) = detect(rhs->iter_values[i]); + CHECK((lhs_iter && rhs_iter) || (!lhs_iter && !rhs_iter)) << "Incompatible binding"; + if (lhs_iter) VisitExpr(lhs_iter.value(), rhs_iter.value()); + if (is_zero(rhs_expr)) { + CHECK(is_zero(lhs_expr)) << "Incompatible binding"; + } else { + const auto* bv = rhs_expr.as(); + if (!bv) { + VisitExpr(lhs_expr, rhs_expr); + } else { + auto it = equal_map_.find(GetRef(bv)); + if (it == equal_map_.end()) { + equal_map_[GetRef(bv)] = lhs_expr; + } else { + CHECK(it->second->IsInstance()); + VisitExpr(lhs_expr, Downcast(it->second)); + } + } + } + } + } else { + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + if (!VisitExpr(op->iter_values[i + offset], rhs->iter_values[i])) return false; + } + const Block& block = op->block; + for (size_t i = 0; i < offset; ++i) { + Var block_var = Downcast(op->iter_values[i]); + auto it = equal_map_.find(block_var); + equal_map_[block->iter_vars[i]->var] = (it == equal_map_.end() ? block_var : it->second); + } + } + } + + return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); +} + +bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Check block equality. + // All iter vars and buffer regions including the order shoudl match. + // When checking iter vars, DefEqual is used to remap variables. + // Only the inner most several axis are compared. Other iter vars are added to extra_block_vars. + if (op->iter_vars.size() < rhs->iter_vars.size()) return false; + + if (is_scope_block) { + //lhs_scope_block = op; + } + // size_t offset = op->iter_vars.size() - rhs->iter_vars.size(); + // for (size_t i = 0; i < rhs->iter_vars.size(); ++i) { + // auto lhs_var = op->iter_vars[i + offset], rhs_var = rhs->iter_vars[i]; + // // Skip iter dom + // if (!DefEqual(lhs_var->var, rhs_var->var)) { + // return false; + // } + // if (lhs_var->iter_type != rhs_var->iter_type) { + // return false; + // } + // } + + // if (is_scope_block) { + // for (size_t i = 0; i < offset; ++i) { + // extra_block_vars_.push_back(op->iter_vars[i]); + // } + // } + + // if (!is_scope_block) { + // if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + // return false; + // } + // if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + // return false; + // } + // if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + // return false; + // } + // if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + // return false; + // } + // } + is_scope_block = false; + return VisitStmt(op->body, rhs->body); +} + +// Exprs +#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ + bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ + } + +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); + +bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + auto lhs = GetRef(op); + if (lhs.same_as(other)) return true; + if (!CompareType(op->dtype, rhs->dtype)) return false; + auto it = equal_map_.find(lhs); + return it != equal_map_.end() && it->second.same_as(other); +} + +bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); +} + +bool TensorizeComparator::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs->type_index() != rhs->type_index()) return false; + auto it = equal_map_.find(lhs); + // If there is already a mapping + if (it != equal_map_.end()) return it->second.same_as(rhs); + equal_map_[lhs] = rhs; + return true; +} + +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first != rhs.first) return false; + if (!lhs.second.same_as(rhs.second)) return false; + return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); +} + +bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, + const Map& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + + auto sort_map = + [](const Map& map) -> std::vector> { + std::vector> ret; + ret.reserve(map.size()); + for (const auto& pair : map) { + ret.emplace_back(pair); + } + sort(ret.begin(), ret.end()); + return ret; + }; + + auto lhs_array = sort_map(lhs), rhs_array = sort_map(rhs); + + for (size_t i = 0; i < lhs.size(); ++i) { + if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + // Remap both buffer itself and buffer data + // Skip buffer shape + bool equal = DefEqual(lhs, rhs) && DefEqual(lhs->data, rhs->data) && + CompareType(lhs->dtype, rhs->dtype) && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } else if (assert_mode_) { + LOG(FATAL) << "Buffers are not matching between:" << lhs << " and " << rhs << lhs->dtype << rhs->dtype<buffer, rhs->buffer)) return false; + // Number of indices in desc_block must be smaller than it in AST + if (rhs->region.size() > lhs->region.size()) return false; + + std::vector lhs_region; + for (const auto& range : lhs->region) { + lhs_region.push_back(Range::FromMinExtent(range->min, range->extent)); + } + size_t offset = lhs_region.size() - rhs->region.size(); + // initialize buffer indices + bool need_update = false; + if (!buffer_indices_.count(lhs->buffer)) { + need_update = true; + buffer_indices_[lhs->buffer] = std::vector(); + } else { + if (offset != buffer_indices_[lhs->buffer].size()) return false; + } + std::vector& indices = buffer_indices_[lhs->buffer]; + for (size_t i = 0; i < offset; ++i) { + const Range& range = lhs_region[i]; + // High-dim region must be element-wise + if (!is_one(range->extent)) return false; + if (need_update) { + indices.push_back(range->min); + } else { + // The order matters since we only map inner block_var to outside block_var + if (!VisitExpr(range->min, indices[i])) return false; + } + } + for (size_t i = 0; i < rhs->region.size(); ++i) { + if (!CompareRange(lhs_region[i + offset], rhs->region[i])) return false; + } + return true; +} + +// Comparator for BufferStoreNode and BufferLoadNode +template +bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + if (rhs->indices.size() > lhs->indices.size()) return false; + // otherwise compare the leading indices + size_t offset = lhs->indices.size() - rhs->indices.size(); + for (size_t i = 0; i < rhs->indices.size(); ++i) { + if (!VisitExpr(lhs->indices[i + offset], rhs->indices[i])) return false; + } + return true; +} + +template +bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(this->*cmp)(lhs[i], rhs[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { + return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); +} + +bool TensorizeComparator::CompareType(const DataType& lhs, const DataType& rhs) { + if (lhs == rhs) return true; + return lhs.code() == rhs.code() && lhs.bits() == rhs.bits() && lhs.lanes() == rhs.lanes(); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h new file mode 100644 index 000000000000..11e458ec3c1b --- /dev/null +++ b/src/tir/schedule/ir_comparator.h @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ +#define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ + +#include "./utils.h" + +namespace tvm { +namespace tir { + +using ExprComparator = ExprFunctor; +using StmtComparator = StmtFunctor; + +/*! \brief Deep comparison to check if two IR ASTs are equivalent for tensorization*/ +class TensorizeComparator : public ExprComparator, public StmtComparator { + public: + /*! + * \brief Constructor of TensorizeComparator + * \param assert_mode Whether to raise an error if the two IR ASTs do not match. + * \param lhs_mod The IRModule of the LHS. This is used for error reporting. + */ + explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true) : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) { + } + + // Map from rhs buffer to lhs buffer + std::unordered_map rhs_buffer_map_; + // Buffer indices mapping + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; + std::vector extra_block_vars_; + // variable remap if any + std::unordered_map equal_map_; + + bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; + bool VisitStmt(const Stmt& n, const Stmt& other) override; + + bool VisitStmt_(const ForNode* op, const Stmt& other) override; + bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + + bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; + bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const ModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; + bool VisitExpr_(const NENode* op, const PrimExpr& other) override; + bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const LENode* op, const PrimExpr& other) override; + bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const GENode* op, const PrimExpr& other) override; + bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; + bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; + bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + + bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); + virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const Map& lhs, const Map& rhs); + template + bool CompareBufferAccess(const T* lhs, const T* rhs); + template + bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + bool CompareRange(const Range& lhs, const Range& rhs); + bool CompareType(const DataType& lhs, const DataType& rhs); + + protected: + IRModule lhs_mod_; + // Map> lhs_root_access_region_; + BlockNode* lhs_scope_block; + bool assert_mode_; + bool is_scope_block = true; + bool is_inner_block = true; + arith::Analyzer analyzer; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 1f3032d10b2a..68c74413a219 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -388,12 +388,12 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); /*! - * \brief Tensorize the computation enclosed by loop with tensor_intrin. + * \brief Tensorize the computation enclosed by loop with the tensor intrinsic. * \param self The state of the schedule - * \param loop_sref The loop to be tensorized. + * \param block_or_loop_sref The block or loop to be tensorized. * \param intrin The tensor intrinsic. */ -TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin& intrin); +TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const TensorIntrin& intrin); /******** Schedule: Annotation ********/ /*! diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 1d536c0f1bb5..3614d5688435 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -19,6 +19,7 @@ #include #include "../utils.h" +#include "../ir_comparator.h" namespace tvm { namespace tir { @@ -469,69 +470,72 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { } /*! - * \brief Update the map from the buffers in the description to the implementation of the tensor + * \brief Update the map from the buffers in the desc to the impl of the tensor * intrinsic. \param intrinsic The tensor intrinsic. \param buffer_map The map to be updated. */ void RemapTensorIntrinBuffers( const TensorIntrin& intrinsic, std::unordered_map* buffer_map) { - ICHECK_EQ(intrinsic->description->params.size(), intrinsic->implementation->params.size()); - for (size_t i = 0; i < intrinsic->description->params.size(); ++i) { - const auto& lhs_var = intrinsic->description->params[i]; - const auto& lhs_buffer = intrinsic->description->buffer_map[lhs_var]; - const auto& rhs_var = intrinsic->implementation->params[i]; - const auto& rhs_buffer = intrinsic->implementation->buffer_map[rhs_var]; + ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size()); + for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) { + const auto& lhs_var = intrinsic->desc->params[i]; + const auto& lhs_buffer = intrinsic->desc->buffer_map[lhs_var]; + const auto& rhs_var = intrinsic->impl->params[i]; + const auto& rhs_buffer = intrinsic->impl->buffer_map[rhs_var]; (*buffer_map)[rhs_buffer] = lhs_buffer; } } -void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin& intrinsic) { +void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const TensorIntrin& intrinsic) { /*! * Check: * - Check buffer binding, including type, alignment, shape and etc. - * - Check the sub AST is equal to the description function. + * - Check the sub AST is equal to the desc function. * * Mutate: * - Blockize the sub AST (please refer blockize for details) * - Bind buffers - * - Mutate the implementation of the tensor intrinsic by replacing its buffers with new + * - Mutate the impl of the tensor intrinsic by replacing its buffers with new * buffers created via match buffer region. * - Replace the sub tree with the mutated function. */ - const auto* loop = loop_sref->StmtAs(); - CHECK(loop) << "Only support tensorize a loop for now"; - const auto* desc_block_realize = - Downcast(intrinsic->description->body)->block->body.as(); + Downcast(intrinsic->desc->body)->block->body.as(); const Block& desc_block = desc_block_realize->block; const auto* impl_block_realize = - Downcast(intrinsic->implementation->body)->block->body.as(); + Downcast(intrinsic->impl->body)->block->body.as(); Block impl_block = impl_block_realize->block; - // Step 1: Blockize the subtree rooted at the given loop - const StmtSRef& block_sref = Blockize(self, loop_sref); + // Step 1: Blockize the subtree rooted at the given loop if needed + StmtSRef block_sref{nullptr}; + if (const auto* loop = block_or_loop_sref->StmtAs()) { + block_sref = Blockize(self, block_or_loop_sref); + } else { + ICHECK(block_or_loop_sref->StmtAs()); + block_sref = block_or_loop_sref; + } const BlockRealize& block_realize = GetBlockRealize(self, block_sref); - // Step 2: Compare the block with the description of the tensor intrinsic, find the correspondence - // between buffers in the block and the description. - TensorizeComparator comparator(/*assert_mode=*/true); + // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence + // between buffers in the block and the desc. + TensorizeComparator comparator(self->mod, /*assert_mode=*/true); comparator.VisitStmt(block_realize, GetRef(desc_block_realize)); - // Step 3: Find the correspondence between buffers in the current AST and the implementation of + // Step 3: Find the correspondence between buffers in the current AST and the impl of // the tensor intrinsic - // Step 3.1: Map from intrinsic func buffer to description func buffer + // Step 3.1: Map from intrinsic func buffer to desc func buffer std::unordered_map intrin_buffer_map; RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map); // Step 3.2: Map form intrinsic func buffer to current AST buffer std::unordered_map buffer_map; for (const auto& pair : intrin_buffer_map) { auto it = comparator.rhs_buffer_map_.find(pair.second); - ICHECK(it != comparator.rhs_buffer_map_.end()); + ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second; buffer_map[pair.first] = it->second; } - // Step 4: Create MatchBufferRegion for the params of the implementation function of the tensor + // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor // intrin to make them subregions of the buffer in the original IR. std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_region_map; for (const auto& read : impl_block->reads) { @@ -541,9 +545,9 @@ void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin buffer_region_map.emplace(write->buffer, write->region); } Array match_buffer_regions; - for (size_t i = 0; i < intrinsic->implementation->params.size(); ++i) { - const auto& param = intrinsic->implementation->params[i]; - const auto& buffer = intrinsic->implementation->buffer_map.at(param); + for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) { + const auto& param = intrinsic->impl->params[i]; + const auto& buffer = intrinsic->impl->buffer_map.at(param); const auto& source = buffer_map.at(buffer); Region region = buffer_region_map.at(buffer); auto extra_indices = comparator.buffer_indices_.at(source); @@ -555,7 +559,7 @@ void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, region))); } - // Step 5: Replace the subtree in the original IR with the tensor intrin implementation. + // Step 5: Replace the subtree in the original IR with the tensor intrin impl. ObjectPtr new_block_ptr = make_object(*block_realize->block.get()); new_block_ptr->body = impl_block->body; ICHECK(new_block_ptr->match_buffers.empty()); @@ -616,14 +620,21 @@ struct TensorizeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String intrin_name) { - return sch->Tensorize(loop_rv, intrin_name); + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin) { + if (const auto* block = block_or_loop_rv.as()) { + sch->Tensorize(GetRef(block), intrin); + } else if (const auto* loop = block_or_loop_rv.as()) { + sch->Tensorize(GetRef(loop), intrin); + } else { + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } } - static String UnpackedAsPython(Array outputs, String loop_rv, String intrin_name) { + static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin) { PythonAPICall py("tensorize"); - py.Input("loop", loop_rv); - py.Input("intrin", intrin_name); + py.Input("block_or_loop", block_or_loop_rv); + py.Input("intrin", intrin); return py.Str(); } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 2736375dc591..b466843f9459 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -188,15 +188,15 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_method(&ScheduleNode::Blockize); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") - .set_body_typed([](Schedule self, LoopRV loop_rv, ObjectRef intrin) { - if (const auto* str = intrin.as()) { - return self->Tensorize(loop_rv, GetRef(str)); - } - if (const auto* p_intrin = intrin.as()) { - return self->Tensorize(loop_rv, GetRef(p_intrin)); + .set_body_typed([](Schedule self, ObjectRef rv, String intrin) { + if (const auto* block_rv = rv.as()) { + self->Tensorize(GetRef(block_rv), intrin); + } else if (const auto* loop_rv = rv.as()) { + self->Tensorize(GetRef(loop_rv), intrin); + } else { + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; } - LOG(FATAL) << "TypeError: Cannot handle type: " << intrin->GetTypeKey(); - throw; }); /******** (FFI) Annotation ********/ diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index b4efbcb960d1..3a37f81b5dbc 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -201,8 +201,7 @@ class BlockInfoCollector : private StmtVisitor { bool is_root_block = srefs_.empty(); // Calculate `BlockInfo::scope` Array child_block_srefs = std::move(block_frames_.back()); - self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); - BlockInfo& info = self_->block_info[scope_root]; + BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); // Set `affine_binding` if (is_root_block) { // If the block doesn't have outer loops and BlockRealize, diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b888833e2385..1e2e57eb6eca 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -367,19 +367,24 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { return new_block; } -void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { - ConcreteScheduleNode::Tensorize(loop_rv, intrin_name); +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) { + ConcreteScheduleNode::Tensorize(loop_rv, intrin); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{loop_rv}, - /*attrs=*/{intrin_name}, + /*attrs=*/{intrin}, /*outputs=*/{})); } -void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& tensor_intrin) { - LOG(FATAL) << "TensorIntrin cannot be directly passed to meta schedule. Please register the " - "tensor intrin and pass the intrin name instead."; +void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) { + ConcreteScheduleNode::Tensorize(block_rv, intrin); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{intrin}, + /*outputs=*/{})); } /******** Schedule: Annotation ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 1cba79a1ee47..3a88e869d309 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -88,8 +88,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv) final; - void Tensorize(const LoopRV& loop_rv, const TensorIntrin& tensor_intrin) final; - void Tensorize(const LoopRV& loop_rv, const String& intrin_name) final; + void Tensorize(const BlockRV& block_rv, const String& intrin) final; + void Tensorize(const LoopRV& loop_rv, const String& intrin) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; From 420d3629b1a3644dc108a5585039f05aa4ccaa44 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 19 Jan 2022 19:13:27 -0500 Subject: [PATCH 09/20] address comments and changed tensorized comparator --- python/tvm/tir/schedule/schedule.py | 12 +- src/tir/ir/function.cc | 22 -- src/tir/schedule/ir_comparator.cc | 278 ++++++++---------- src/tir/schedule/ir_comparator.h | 37 +-- .../schedule/primitive/blockize_tensorize.cc | 242 +++++++++------ .../unittest/test_tir_schedule_blockize.py | 58 ++-- .../unittest/test_tir_schedule_tensorize.py | 81 +++-- 7 files changed, 351 insertions(+), 379 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index c7caf6737549..0d432877ed7f 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1813,17 +1813,17 @@ def after_blockize( B: T.Buffer[(128, 128), "float32"] )-> None: for i_0, j_0 in T.grid(8, 8): - with T.block("blockized_B"): + with T.block("B_o"): vio, vjo = T.axis.remap("SS", [i_0, j_0]) T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) for i_1, j_1 in T.grid(16, 16): with T.block("B"): - vi = T.axis.spatial(128, vio * 16 + i_1) - vj = T.axis.spatial(128, vjo * 16 + j_1) - T.reads(A[vi, vj]) - T.writes(B[vi, vj]) - B[vi, vj] = A[vi, vj] * T.float32(2) + vi, vj = T.axis.remap("SS", [i_1, j_1]) + T.reads(A[vio * 16 + vi, vjo * 16 + vj]) + T.writes(B[vio * 16 + vi, vjo * 16 + vj]) + B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * + T.float32(2) Note ---- diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 3b2b4057aff4..3b35cd45b8ba 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -83,28 +83,6 @@ TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { } ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); - // check both functions' bodies are directly block - const auto* desc_realize = - Downcast(desc->body)->block->body.as(); - const auto* impl_realize = - Downcast(impl->body)->block->body.as(); - CHECK(desc_realize != nullptr) << "description function's body expect a directly block"; - CHECK(impl_realize != nullptr) << "intrinsic function's body expect a directly block"; - - const Block& desc_block = desc_realize->block; - const Block& impl_block = impl_realize->block; - - // check block var number and iter type - CHECK_EQ(desc_block->iter_vars.size(), impl_block->iter_vars.size()) - << "ValueError: The blocks in the description and the implementation should have the same number of block vars"; - for (size_t i = 0; i < desc_block->iter_vars.size(); i++) { - const IterVar& desc_var = desc_block->iter_vars[i]; - const IterVar& impl_var = impl_block->iter_vars[i]; - CHECK(desc_var->iter_type == impl_var->iter_type) - << "Block iter_type mismatch between " << IterVarType2String(desc_var->iter_type) << " and " - << IterVarType2String(impl_var->iter_type); - } - ObjectPtr n = make_object(); n->desc = std::move(desc); n->impl = std::move(impl); diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 35c9e7098a31..2d11411f6735 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -26,8 +26,12 @@ namespace tir { class TensorIntrinMismatchError : public ScheduleError { public: - explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt) : - lhs_mod_(std::move(lhs_mod)), lhs_stmt_(std::move(lhs_stmt)), rhs_stmt_(std::move(rhs_stmt)) { + explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt, + std::vector error_messages) + : lhs_mod_(std::move(lhs_mod)), + lhs_stmt_(std::move(lhs_stmt)), + rhs_stmt_(std::move(rhs_stmt)), + error_messages_(std::move(error_messages)) { ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); } @@ -38,6 +42,9 @@ class TensorIntrinMismatchError : public ScheduleError { String DetailRenderTemplate() const final { std::ostringstream os; os << "The stmt {0} doesn't match the tensor intrin\n " << rhs_stmt_; + for (const auto& msg : error_messages_) { + os << msg << std::endl; + } return os.str(); } @@ -49,19 +56,29 @@ class TensorIntrinMismatchError : public ScheduleError { IRModule lhs_mod_; Stmt lhs_stmt_; Stmt rhs_stmt_; + std::vector error_messages_; }; /* Override the dispatcher to make sure RHS is always valid */ bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { - bool equal = (n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other); + bool equal = n.same_as(other) || + ((n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other)); if (!equal && assert_mode_ && (n->IsInstance() || n->IsInstance())) { - throw TensorIntrinMismatchError(lhs_mod_, n, other); + throw TensorIntrinMismatchError(lhs_mod_, n, other, std::move(error_messages_)); } return equal; } bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { - return (n->type_index() == other->type_index()) && ExprComparator::VisitExpr(n, other); + bool equal = + n.same_as(other) || ((n->type_index() == other->type_index()) && n->dtype == other->dtype && + ExprComparator::VisitExpr(n, other)); + if (!equal && assert_mode_) { + std::ostringstream os; + os << "Expression mismatch: " << n << " vs " << other; + EmitError(os.str()); + } + return equal; } bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { @@ -91,109 +108,36 @@ bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& othe bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { const auto* rhs = other.as(); - // Skip Compare binding values if the block is scope block (the outermost one). if (!is_scope_block) { - size_t offset = op->iter_values.size() - rhs->iter_values.size(); - if (rhs->iter_values.size() > op->iter_values.size()) return false; - if (is_inner_block) { - // weak pattern matching for the inner block (the son of the scope block) - // where the pattern is v + iter <=> expr + iter - for (size_t i = 0; i < rhs->iter_values.size(); ++i) { - PrimExpr lhs_expr, rhs_expr; - Optional lhs_iter, rhs_iter; - auto detect = [](const PrimExpr& binding) -> std::pair> { - arith::PVar expr; - arith::PVar iter; - if (iter.Match(binding)) { - return std::make_pair(0, iter.Eval()); - } else if ((expr + iter).Match(binding)) { - return std::make_pair(expr.Eval(), iter.Eval()); - } else if ((iter + expr).Match(binding)) { - return std::make_pair(expr.Eval(), iter.Eval()); - } else { - return std::make_pair(expr.Eval(), NullOpt); - } - }; - std::tie(lhs_expr, lhs_iter) = detect(op->iter_values[i + offset]); - std::tie(rhs_expr, rhs_iter) = detect(rhs->iter_values[i]); - CHECK((lhs_iter && rhs_iter) || (!lhs_iter && !rhs_iter)) << "Incompatible binding"; - if (lhs_iter) VisitExpr(lhs_iter.value(), rhs_iter.value()); - if (is_zero(rhs_expr)) { - CHECK(is_zero(lhs_expr)) << "Incompatible binding"; - } else { - const auto* bv = rhs_expr.as(); - if (!bv) { - VisitExpr(lhs_expr, rhs_expr); - } else { - auto it = equal_map_.find(GetRef(bv)); - if (it == equal_map_.end()) { - equal_map_[GetRef(bv)] = lhs_expr; - } else { - CHECK(it->second->IsInstance()); - VisitExpr(lhs_expr, Downcast(it->second)); - } - } - } - } - } else { - for (size_t i = 0; i < rhs->iter_values.size(); ++i) { - if (!VisitExpr(op->iter_values[i + offset], rhs->iter_values[i])) return false; - } - const Block& block = op->block; - for (size_t i = 0; i < offset; ++i) { - Var block_var = Downcast(op->iter_values[i]); - auto it = equal_map_.find(block_var); - equal_map_[block->iter_vars[i]->var] = (it == equal_map_.end() ? block_var : it->second); - } + if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) { + return false; } } - return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); } bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { const auto* rhs = other.as(); // Check block equality. - // All iter vars and buffer regions including the order shoudl match. + // All iter vars and buffer regions including the order should match. // When checking iter vars, DefEqual is used to remap variables. - // Only the inner most several axis are compared. Other iter vars are added to extra_block_vars. - if (op->iter_vars.size() < rhs->iter_vars.size()) return false; - - if (is_scope_block) { - //lhs_scope_block = op; + if (!is_scope_block) { + if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) { + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + return false; + } + if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + return false; + } + } + if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + return false; } - // size_t offset = op->iter_vars.size() - rhs->iter_vars.size(); - // for (size_t i = 0; i < rhs->iter_vars.size(); ++i) { - // auto lhs_var = op->iter_vars[i + offset], rhs_var = rhs->iter_vars[i]; - // // Skip iter dom - // if (!DefEqual(lhs_var->var, rhs_var->var)) { - // return false; - // } - // if (lhs_var->iter_type != rhs_var->iter_type) { - // return false; - // } - // } - - // if (is_scope_block) { - // for (size_t i = 0; i < offset; ++i) { - // extra_block_vars_.push_back(op->iter_vars[i]); - // } - // } - - // if (!is_scope_block) { - // if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { - // return false; - // } - // if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { - // return false; - // } - // if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { - // return false; - // } - // if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { - // return false; - // } - // } is_scope_block = false; return VisitStmt(op->body, rhs->body); } @@ -225,24 +169,24 @@ TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; + return op->value == rhs->value; } bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; + return op->value == rhs->value; } bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - return CompareType(op->dtype, rhs->dtype) && VisitExpr(op->value, rhs->value); + return VisitExpr(op->value, rhs->value); } bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { const auto* rhs = other.as(); auto lhs = GetRef(op); if (lhs.same_as(other)) return true; - if (!CompareType(op->dtype, rhs->dtype)) return false; + if (op->dtype != rhs->dtype) return false; auto it = equal_map_.find(lhs); return it != equal_map_.end() && it->second.same_as(other); } @@ -252,13 +196,14 @@ bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& o return CompareBufferAccess(op, rhs); } -bool TensorizeComparator::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { +bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { if (lhs.same_as(rhs)) return true; - if (lhs->type_index() != rhs->type_index()) return false; auto it = equal_map_.find(lhs); // If there is already a mapping if (it != equal_map_.end()) return it->second.same_as(rhs); + // Otherwise remap lhs to rhs equal_map_[lhs] = rhs; + analyzer_.Bind(lhs, rhs); return true; } @@ -276,16 +221,13 @@ bool TensorizeComparator::CompareAnnotationMap(const Map& lhs auto sort_map = [](const Map& map) -> std::vector> { - std::vector> ret; - ret.reserve(map.size()); - for (const auto& pair : map) { - ret.emplace_back(pair); - } + std::vector> ret(map.begin(), map.end()); sort(ret.begin(), ret.end()); return ret; }; - auto lhs_array = sort_map(lhs), rhs_array = sort_map(rhs); + std::vector> lhs_array = sort_map(lhs); + std::vector> rhs_array = sort_map(rhs); for (size_t i = 0; i < lhs.size(); ++i) { if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; @@ -295,51 +237,65 @@ bool TensorizeComparator::CompareAnnotationMap(const Map& lhs bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { if (lhs.same_as(rhs)) return true; - // Remap both buffer itself and buffer data - // Skip buffer shape - bool equal = DefEqual(lhs, rhs) && DefEqual(lhs->data, rhs->data) && - CompareType(lhs->dtype, rhs->dtype) && lhs.scope() == rhs.scope(); - if (equal) { - rhs_buffer_map_[rhs] = lhs; - } else if (assert_mode_) { - LOG(FATAL) << "Buffers are not matching between:" << lhs << " and " << rhs << lhs->dtype << rhs->dtype<data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } } return equal; } bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { - // Only for block region declaration if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; - // Number of indices in desc_block must be smaller than it in AST - if (rhs->region.size() > lhs->region.size()) return false; - - std::vector lhs_region; - for (const auto& range : lhs->region) { - lhs_region.push_back(Range::FromMinExtent(range->min, range->extent)); - } - size_t offset = lhs_region.size() - rhs->region.size(); - // initialize buffer indices - bool need_update = false; - if (!buffer_indices_.count(lhs->buffer)) { - need_update = true; - buffer_indices_[lhs->buffer] = std::vector(); + int offset = static_cast(lhs->region.size()) - static_cast(rhs->region.size()); + // Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS + if (offset < 0) return false; + + auto it = buffer_indices_.find(lhs->buffer); + if (it == buffer_indices_.end()) { + // Update base indices for the buffer, this can only happen if it is visiting the scope block. + ICHECK(is_scope_block); + std::vector indices_base; + indices_base.reserve(lhs->region.size()); + for (int i = 0; i < offset; i++) { + // High-dim region must be element-wise + if (!is_one(lhs->region[i]->extent)) return false; + indices_base.emplace_back(lhs->region[i]->min); + } + for (size_t i = 0; i < rhs->region.size(); i++) { + // save base index + indices_base.emplace_back(lhs->region[i + offset]->min); + // check extent match + if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + return false; + } + } + buffer_indices_.emplace(lhs->buffer, std::move(indices_base)); } else { - if (offset != buffer_indices_[lhs->buffer].size()) return false; - } - std::vector& indices = buffer_indices_[lhs->buffer]; - for (size_t i = 0; i < offset; ++i) { - const Range& range = lhs_region[i]; - // High-dim region must be element-wise - if (!is_one(range->extent)) return false; - if (need_update) { - indices.push_back(range->min); - } else { - // The order matters since we only map inner block_var to outside block_var - if (!VisitExpr(range->min, indices[i])) return false; + // Check the base indices are consistent. + const std::vector& indices_base = it->second; + for (int i = 0; i < offset; i++) { + // High-dim region must be element-wise + if (!is_one(lhs->region[i]->extent)) return false; + if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) return false; + } + for (size_t i = 0; i < rhs->region.size(); i++) { + // check extent match + if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + return false; + } + PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]); + if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { + return false; + } } - } - for (size_t i = 0; i < rhs->region.size(); ++i) { - if (!CompareRange(lhs_region[i + offset], rhs->region[i])) return false; } return true; } @@ -348,11 +304,22 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf template bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; - if (rhs->indices.size() > lhs->indices.size()) return false; - // otherwise compare the leading indices - size_t offset = lhs->indices.size() - rhs->indices.size(); - for (size_t i = 0; i < rhs->indices.size(); ++i) { - if (!VisitExpr(lhs->indices[i + offset], rhs->indices[i])) return false; + int offset = static_cast(lhs->indices.size()) - static_cast(rhs->indices.size()); + if (offset < 0) return false; + auto it = buffer_indices_.find(lhs->buffer); + ICHECK(it != buffer_indices_.end()); + const std::vector& indices_base = (*it).second; + ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset); + for (size_t i = 0; i < rhs->indices.size(); i++) { + PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i + offset]; + if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) { + if (assert_mode_) { + std::ostringstream os; + os << "Buffer indices mismatch: " << lhs->indices[i + offset] << " vs " << rhs->indices[i]; + EmitError(os.str()); + } + return false; + } } return true; } @@ -371,9 +338,12 @@ bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); } -bool TensorizeComparator::CompareType(const DataType& lhs, const DataType& rhs) { - if (lhs == rhs) return true; - return lhs.code() == rhs.code() && lhs.bits() == rhs.bits() && lhs.lanes() == rhs.lanes(); +bool TensorizeComparator::CompareIterVar(const IterVar& lhs, const IterVar& rhs) { + return DefEqual(lhs->var, rhs->var) && lhs->iter_type == rhs->iter_type; +} + +void TensorizeComparator::EmitError(const std::string& error_message) { + error_messages_.push_back(error_message); } } // namespace tir diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index 11e458ec3c1b..b81cd6a4a6f1 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -35,16 +35,8 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { * \param assert_mode Whether to raise an error if the two IR ASTs do not match. * \param lhs_mod The IRModule of the LHS. This is used for error reporting. */ - explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true) : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) { - } - - // Map from rhs buffer to lhs buffer - std::unordered_map rhs_buffer_map_; - // Buffer indices mapping - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; - std::vector extra_block_vars_; - // variable remap if any - std::unordered_map equal_map_; + explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true) + : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) {} bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; bool VisitStmt(const Stmt& n, const Stmt& other) override; @@ -78,7 +70,13 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; - bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); + /*! \brief Map from RHS buffer to LHS buffer */ + std::unordered_map rhs_buffer_map_; + /*! \brief Base indices of the LHS buffer. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; + + protected: + bool DefEqual(const Var& lhs, const Var& rhs); virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); bool CompareAnnotation(const std::pair& lhs, @@ -89,16 +87,21 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { template bool CompareArray(const Array& lhs, const Array& rhs, F cmp); bool CompareRange(const Range& lhs, const Range& rhs); - bool CompareType(const DataType& lhs, const DataType& rhs); + bool CompareIterVar(const IterVar& lhs, const IterVar& rhs); + void EmitError(const std::string& error_message); - protected: + /*! \brief IRModule of the LHS stmt. */ IRModule lhs_mod_; - // Map> lhs_root_access_region_; - BlockNode* lhs_scope_block; + /*! \brief Whether assertion mode is enabled. */ bool assert_mode_; + /*! \brief Whether it is visiting the scope block (the outermost block). */ bool is_scope_block = true; - bool is_inner_block = true; - arith::Analyzer analyzer; + /*! \brief The arithmetic analyzer. */ + arith::Analyzer analyzer_; + /*! \brief Additional error messages. Only used when assert_mode is true. */ + std::vector error_messages_; + // variable remap if any + std::unordered_map equal_map_; }; } // namespace tir diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 3614d5688435..59e9779bb290 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -18,8 +18,8 @@ */ #include -#include "../utils.h" #include "../ir_comparator.h" +#include "../utils.h" namespace tvm { namespace tir { @@ -66,7 +66,7 @@ class SubspaceNotDivisibleError : public ScheduleError { * \param bindings The values of iter_vars * \param outer_loops Iterators outside the subspace. * \param inner_loops Iterators of the subspace - * \param predicate The predicate constaints on the input iterators. + * \param predicate The predicate constraint on the input iterators. * \return The result of the subspace division. */ Array> TrivialSubspaceDivision(const Array& iter_vars, @@ -75,16 +75,25 @@ Array> TrivialSubspaceDivision(const Array& iter const Array& inner_iters, const PrimExpr& predicate) { if (!is_one(predicate)) return {}; - std::vector> res; + Array> res; std::unordered_set outer_loop_vars; std::unordered_set inner_loop_vars; - for (const Var& var : outer_iters) { - outer_loop_vars.insert(var.get()); - } - for (const Var& var : inner_iters) { - inner_loop_vars.insert(var.get()); - } - const arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); + + auto make_uses_var = [](const Array& vars) -> std::function { + std::unordered_set var_set; + var_set.reserve(vars.size()); + for (const Var& var : vars) { + var_set.insert(var.get()); + } + return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool { + return UsesVar(expr, [&var_set](const VarNode* var) { + return var_set.count(var); // + }); + }; + }; + auto use_outer_loop_vars = make_uses_var(outer_iters); + auto use_inner_loop_vars = make_uses_var(inner_iters); + arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); for (size_t i = 0; i < bindings.size(); ++i) { bool outer = UsesVar( @@ -100,18 +109,11 @@ Array> TrivialSubspaceDivision(const Array& iter iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); } if (outer && !inner) { - arith::IterMark outer{nullptr}; - const auto& outer_iter = iter_mark; - const auto& inner_iter = unit_iter_mark; - res.push_back({outer_iter, inner_iter}); + res.push_back({/*outer_iter=*/iter_mark, /*inner_iter=*/unit_iter_mark}); } else if (inner && !outer) { - const auto& outer_iter = unit_iter_mark; - const auto& inner_iter = iter_mark; - res.push_back({outer_iter, inner_iter}); + res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/iter_mark}); } else if (!outer && !inner) { - const auto& outer_iter = unit_iter_mark; - const auto& inner_iter = unit_iter_mark; - res.push_back({outer_iter, inner_iter}); + res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/unit_iter_mark}); } else { return {}; } @@ -140,7 +142,7 @@ Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_r const PrimExpr& binding = inner_block_realize->iter_values[i]; if (iter_var->iter_type == IterVarType::kDataPar && UsesVar(block->init.value(), - [&iter_var](const VarNode* var) { return var == iter_var->var.get(); })) { + [tgt_var = iter_var->var.get()](const VarNode* var) { return var == tgt_var; })) { init_block_iters.push_back(iter_var); init_bindings.push_back(binding); } @@ -151,8 +153,9 @@ Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_r for (const ForNode* inner_loop : inner_loops) { for (const PrimExpr& init_binding : init_bindings) { if (UsesVar(init_binding, - [inner_loop](const VarNode* var) { return var == inner_loop->loop_var.get(); })) { + [tgt_var = inner_loop->loop_var.get()](const VarNode* var) { return var == tgt_var; })) { init_loops.push_back(inner_loop); + break; } } } @@ -169,16 +172,16 @@ Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_r } // Step 4: Generate loop nests and the init block - Block init_block{/*iter_vars=*/init_block_iters, // - /*reads=*/{}, // - /*writes=*/block->writes, // - /*name_hint=*/block->name_hint + "_init", // - /*body=*/block->init.value(), // - /*init=*/NullOpt}; Stmt new_init = BlockRealize( /*iter_values=*/init_bindings, /*predicate=*/inner_block_realize->predicate, - /*block=*/std::move(init_block)); + /*block=*/ + Block{/*iter_vars=*/init_block_iters, + /*reads=*/{}, + /*writes=*/block->writes, + /*name_hint=*/block->name_hint + "_init", + /*body=*/block->init.value(), + /*init=*/NullOpt}); // Step 5: Generate the parent loops for the init block for (const ForNode* init_loop : init_loops) { @@ -207,7 +210,7 @@ class LoopSubspaceCollector { /*! * \brief Collect the parent loops of the block and store the result in the corresponding fields. * \param block_sref The sref to the target block. - * \param loop_sref The sref to the separator loop. + * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop. */ void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) { bool inner = true; @@ -287,16 +290,14 @@ class BlockizedBindingExtractor { * \param division The result of the subspace division. */ void ExtractBindings(const Array& iter_vars, - const Array>& division) { - ICHECK(iter_vars.size() + 1 == division.size()); + const Array>& division, arith::Analyzer* analyzer) { + ICHECK_EQ(iter_vars.size() + 1, division.size()); for (size_t i = 0; i < iter_vars.size(); ++i) { const IterVar& iter_var = iter_vars[i]; - const arith::IterMapExprNode* outer_binding = - division[i][0]->source.as(); - const arith::IterMapExprNode* inner_binding = - division[i][1]->source.as(); - ICHECK(outer_binding); - ICHECK(inner_binding); + arith::IterMark outer_mark = division[i][0]; + arith::IterMark inner_mark = division[i][1]; + const auto* outer_binding = TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode); + const auto* inner_binding = TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode); // After computing the subspace division, bindings[i] can be written as // outer_binding * inner_binding->extent + inner_binding @@ -316,20 +317,21 @@ class BlockizedBindingExtractor { arith::NormalizeIterMapToExpr(GetRef(outer_binding))); outer_iter_vars.push_back(outer_var); PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; - if (const auto* op = division[i][1]->source.as()) { - base = base + op->base; - inner_bindings.push_back(base + - arith::NormalizeIterMapToExpr(arith::IterSumExpr(op->args, 0))); - } else { - inner_bindings.push_back( - base + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); - } - inner_iter_vars.push_back(iter_var); inner_iter_relaxed_range.Set(iter_var->var, arith::IntSet::FromMinExtent(base, division[i][1]->extent)); + IterVar new_iter = iter_var; + auto* new_iter_node = new_iter.CopyOnWrite(); + new_iter_node->dom = Range::FromMinExtent(0, division[i][1]->extent); + analyzer->Bind(new_iter->var, new_iter->dom); + // new_iter_node->dom + inner_iter_vars.push_back(new_iter); + inner_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); + inner_iter_subst_map.Set(iter_var->var, base + new_iter->var); } } } + Map inner_iter_subst_map; /*! \brief Iters of the outer block. */ Array outer_iter_vars; /*! \brief Iters of the outer block. */ @@ -345,6 +347,54 @@ class BlockizedBindingExtractor { Map inner_iter_relaxed_range; }; +/*! + * \brief Replacer for the inner block after blockize. Inner block iters will be replaced with + * base + inner_iter and the expressions after substituion will be simplified if possible. + */ +class InnerIterReplacer : public StmtExprMutator { + public: + /*! + * \brief The constructor + * \param subst_map The substitution map of the inner block iters. + * \param analyzer The arithmetic analyzer. + * \param block_sref_reuse The map to save the block reuse information. + */ + InnerIterReplacer(Map subst_map, arith::Analyzer* analyzer, + Map* block_sref_reuse) + : subst_map_(std::move(subst_map)), + analyzer_(analyzer), + block_sref_reuse_(block_sref_reuse) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = subst_map_.find(GetRef(op)); + if (it != subst_map_.end()) { + return (*it).second; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr(const PrimExpr& op) final { + PrimExpr result = StmtExprMutator::VisitExpr(op); + if (!result.same_as(op)) { + return analyzer_->Simplify(result); + } + return result; + } + + Stmt VisitStmt_(const BlockNode* op) final { + Stmt result = StmtExprMutator::VisitStmt_(op); + if (!result.same_as(GetRef(op))) { + block_sref_reuse_->Set(GetRef(op), Downcast(result)); + } + return result; + } + + private: + Map subst_map_; + arith::Analyzer* analyzer_; + Map* block_sref_reuse_; +}; + /*! * \brief Compute the access region of the outer block by relaxing the inner loops. * \param buffer_region The original buffer region. @@ -398,7 +448,7 @@ BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extrac new_writes.MutateByApply(f_mutate); // Step 3: Generate the body of the outer block. The body of the outer block is the inner block - // realize and its surounding loops. + // realize and its surrounding loops. Stmt outer_block_body = inner_block_realize; for (const ForNode* loop : inner_loops) { ObjectPtr new_loop = make_object(*loop); @@ -407,16 +457,15 @@ BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extrac } // Step 4: Generate the outer block and block realize. - Block outer_block{/*iter_vars=*/extractor.outer_iter_vars, // - /*reads=*/new_reads, // - /*writes=*/new_writes, // - /*name_hint=*/"blockized_" + block->name_hint, // - /*body=*/std::move(outer_block_body), // - /*init=*/new_init}; - BlockRealize outer_block_realize{/*iter_values=*/extractor.outer_bindings, - /*predicate=*/std::move(predicate), - /*block=*/std::move(outer_block)}; - return outer_block_realize; + return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings), + /*predicate=*/std::move(predicate), + /*block=*/ + Block(/*iter_vars=*/std::move(extractor.outer_iter_vars), // + /*reads=*/std::move(new_reads), // + /*writes=*/std::move(new_writes), // + /*name_hint=*/block->name_hint + "_o", // + /*body=*/std::move(outer_block_body), // + /*init=*/std::move(new_init))); } StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { @@ -439,54 +488,60 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { // Step 4: Generate bindings for the outer block and the inner block based on the result of // the subspace division. BlockizedBindingExtractor extractor; - extractor.ExtractBindings(block->iter_vars, division); + extractor.ExtractBindings(block->iter_vars, division, &analyzer); const PrimExpr& outer_pred = division.back()[0]->extent; const PrimExpr& inner_pred = division.back()[1]->extent; // Step 5: Generate the inner block. + Map block_sref_reuse; BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite(); BlockNode* inner_block = inner_block_realize->block.CopyOnWrite(); inner_block_realize->iter_values = extractor.inner_bindings; inner_block_realize->predicate = inner_pred; inner_block->iter_vars = extractor.inner_iter_vars; inner_block->init = NullOpt; + InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer, + &block_sref_reuse); + inner_block_realize->block = Downcast(replacer(inner_block_realize->block)); // Step 6: Generate the outer block. BlockRealize outer_realize = GenerateBlockizedOuterBlock(extractor, block, GetRef(inner_block_realize), collector.inner_loops, outer_pred, &analyzer); // Step 7: Do the actual replacement - self->Replace(loop_sref, outer_realize, {{block, GetRef(inner_block)}}); + self->Replace(loop_sref, outer_realize, block_sref_reuse); // Step 8: Update the cached flags - const StmtSRef& outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); + StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false, /*require_subtree_compact_dataflow=*/false); - BlockInfo old_block_info = self->GetBlockInfo(scope_root); + bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); - // 'affine_binding' depends on the outer loops and are not changed. - self->block_info[scope_root].affine_binding = old_block_info.affine_binding; + self->block_info[scope_root].affine_binding = scope_block_affine_binding; return outer_block_sref; } /*! * \brief Update the map from the buffers in the desc to the impl of the tensor - * intrinsic. \param intrinsic The tensor intrinsic. \param buffer_map The map to be updated. + * intrinsic. + * \param intrinsic The tensor intrinsic. + * \param buffer_map The map to be updated. */ void RemapTensorIntrinBuffers( const TensorIntrin& intrinsic, std::unordered_map* buffer_map) { ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size()); for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) { - const auto& lhs_var = intrinsic->desc->params[i]; - const auto& lhs_buffer = intrinsic->desc->buffer_map[lhs_var]; - const auto& rhs_var = intrinsic->impl->params[i]; - const auto& rhs_buffer = intrinsic->impl->buffer_map[rhs_var]; + const Var& lhs_var = intrinsic->desc->params[i]; + const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var]; + const Var& rhs_var = intrinsic->impl->params[i]; + const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var]; (*buffer_map)[rhs_buffer] = lhs_buffer; } } -void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const TensorIntrin& intrinsic) { +void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, + const TensorIntrin& intrinsic) { /*! * Check: * - Check buffer binding, including type, alignment, shape and etc. @@ -499,11 +554,8 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const Ten * buffers created via match buffer region. * - Replace the sub tree with the mutated function. */ - const auto* desc_block_realize = - Downcast(intrinsic->desc->body)->block->body.as(); - const Block& desc_block = desc_block_realize->block; - const auto* impl_block_realize = - Downcast(intrinsic->impl->body)->block->body.as(); + const BlockRealize& desc_block_realize = Downcast(intrinsic->desc->body); + const BlockRealize& impl_block_realize = Downcast(intrinsic->impl->body); Block impl_block = impl_block_realize->block; // Step 1: Blockize the subtree rooted at the given loop if needed @@ -519,11 +571,10 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const Ten // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence // between buffers in the block and the desc. TensorizeComparator comparator(self->mod, /*assert_mode=*/true); - comparator.VisitStmt(block_realize, GetRef(desc_block_realize)); + comparator.VisitStmt(block_realize, desc_block_realize); // Step 3: Find the correspondence between buffers in the current AST and the impl of // the tensor intrinsic - // Step 3.1: Map from intrinsic func buffer to desc func buffer std::unordered_map intrin_buffer_map; RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map); @@ -545,18 +596,24 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const Ten buffer_region_map.emplace(write->buffer, write->region); } Array match_buffer_regions; + match_buffer_regions.reserve(intrinsic->impl->params.size()); for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) { const auto& param = intrinsic->impl->params[i]; const auto& buffer = intrinsic->impl->buffer_map.at(param); const auto& source = buffer_map.at(buffer); - Region region = buffer_region_map.at(buffer); - auto extra_indices = comparator.buffer_indices_.at(source); - std::vector extra_buffer_ranges; - std::transform(extra_indices.begin(), extra_indices.end(), - std::back_inserter(extra_buffer_ranges), - [](const PrimExpr& index) { return Range::FromMinExtent(index, 1); }); - region.insert(region.begin(), extra_buffer_ranges.begin(), extra_buffer_ranges.end()); - match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, region))); + // add the detected base indices to each buffer access region of the tensor intrinsic + Region old_region = buffer_region_map.at(buffer); + const auto& indices_base = comparator.buffer_indices_.at(source); + int offset = indices_base.size() - old_region.size(); + Region new_region; + new_region.reserve(source->shape.size()); + for (int i = 0; i < offset; i++) { + new_region.push_back(Range::FromMinExtent(indices_base[i], 1)); + } + for (int i = 0; i < old_region.size(); i++) { + new_region.push_back(Range::FromMinExtent(indices_base[i + offset], old_region[i]->extent)); + } + match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, new_region))); } // Step 5: Replace the subtree in the original IR with the tensor intrin impl. @@ -566,17 +623,6 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const Ten new_block_ptr->match_buffers = std::move(match_buffer_regions); Block new_block(new_block_ptr); - // Substitution map for the intrin block to get the correct bindings. - Map bv_map; - for (size_t i = 0; i < desc_block->iter_vars.size(); ++i) { - auto it = comparator.equal_map_.find(desc_block->iter_vars[i]->var); - if (it != comparator.equal_map_.end()) { - bv_map.Set(impl_block->iter_vars[i]->var, Downcast(it->second)); - } else { - bv_map.Set(impl_block->iter_vars[i]->var, Integer(0)); - } - } - new_block = Downcast(SubstituteInScope(new_block, bv_map)); self->Replace(block_sref, new_block, {{block_realize->block, new_block}}); // Step 6: Update the cached flags. @@ -626,8 +672,8 @@ struct TensorizeTraits : public UnpackedInstTraits { } else if (const auto* loop = block_or_loop_rv.as()) { sch->Tensorize(GetRef(loop), intrin); } else { - LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); - throw; + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); } } diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index 84fd329899ac..b4a16a8231b8 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -62,7 +62,7 @@ def single_elementwise_blockized2( T.writes(B[vi, 0:128]) for j in T.serial(128): with T.block("B"): - vj = T.axis.spatial(128, j) + vj = T.axis.remap("S", [j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) @@ -90,38 +90,28 @@ def two_elementwise_blockized( A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"] ) -> None: - with T.block("root"): - T.reads([]) - T.writes([]) - B = T.alloc_buffer([128, 128], "float32") - for i0_outer in range(0, 8): - for i1_outer in range(0, 8): - with T.block("blockized_B"): - vio = T.axis.S(8, i0_outer) - vjo = T.axis.S(8, i1_outer) - T.reads([A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) - T.writes([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) - for i0_inner in range(0, 16): - for i1_inner in range(0, 16): - with T.block("B"): - vi = T.axis.S(128, ((vio * 16) + i0_inner)) - vj = T.axis.S(128, ((vjo * 16) + i1_inner)) - T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) - T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) - B[vi, vj] = A[vi, vj] * T.float32(2) - with T.block("blockized_C"): - vio = T.axis.S(8, i0_outer) - vjo = T.axis.S(8, i1_outer) - T.reads([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) - T.writes([C[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) - for ax0 in range(0, 16): - for ax1 in range(0, 16): - with T.block("C"): - vi = T.axis.S(128, ((vio * 16) + ax0)) - vj = T.axis.S(128, ((vjo * 16) + ax1)) - T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) - T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) - C[vi, vj] = B[vi, vj] + T.float32(1) + B = T.alloc_buffer([128, 128], dtype="float32") + for i_0, j_0 in T.grid(8, 8): + with T.block("blockized_B"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i_1, j_1]) + T.reads(A[vio * 16 + vi, vjo * 16 + vj]) + T.writes(B[vio * 16 + vi, vjo * 16 + vj]) + B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2) + with T.block("blockized_C"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for ax0, ax1 in T.grid(16, 16): + with T.block("C"): + vi, vj = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[vio * 16 + vi, vjo * 16 + vj]) + T.writes(C[vio * 16 + vi, vjo * 16 + vj]) + C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 16 + vj] + T.float32(1) @T.prim_func @@ -160,6 +150,7 @@ def test_blockize_outer(): B = s.get_block("B") x, y = s.get_loops(B) s.blockize(x) + print(s.mod['main'].script()) tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1) verify_trace_roundtrip(sch=s, mod=func) @@ -171,7 +162,6 @@ def test_blockize_inner(): B = s.get_block("B") x, y = s.get_loops(B) s.blockize(y) - print(s.mod["main"].script()) tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized2) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index a9eb6860f9cd..401a39f379b7 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -33,14 +33,11 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) with T.block("root"): - vi = T.axis.S(16, 0) - vj = T.axis.S(16, 0) - vk = T.axis.R(16, 0) + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) for i, j, k in T.grid(16, 16, 16): with T.block("update"): - vii = T.axis.S(16, vi + i) - vjj = T.axis.S(16, vj + j) - vkk = T.axis.R(16, vk + k) + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] @@ -51,17 +48,8 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) with T.block("root"): - vi = T.axis.S(16, 0) - vj = T.axis.S(16, 0) - vk = T.axis.R(16, 0) - T.reads( - [ - C[vi : vi + 16, vj : vj + 16], - A[vi : vi + 16, vk : vk + 16], - B[vj : vj + 16, vk : vk + 16], - ] - ) - T.writes(C[vi : vi + 16, vj : vj + 16]) + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) T.evaluate( T.tvm_mma_sync( C.data, @@ -84,10 +72,11 @@ def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, ()) with T.block("root"): - v0 = T.axis.R(4, 0) + T.reads(C[()], A[0 : 4], B[0 : 4]) + T.writes(C[()]) for i in range(0, 4): with T.block("update"): - vi = T.axis.R(4, v0 + i) + vi = T.axis.remap("R", [i]) C[()] = C[()] + A[vi] * B[vi] @@ -98,8 +87,7 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (), offset_factor=1) with T.block("root"): - v0 = T.axis.R(4, 0) - T.reads(C[()], A[v0 : v0 + 4], B[v0 : v0 + 4]) + T.reads(C[()], A[0 : 4], B[0 : 4]) T.writes(C[()]) T.evaluate( T.call_extern( @@ -117,43 +105,36 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 1)) - B = T.match_buffer(b, (16, 1)) - C = T.match_buffer(c, (16, 16)) + A = T.match_buffer(a, (16, 1), offset_factor=1) + B = T.match_buffer(b, (16, 1), offset_factor=1) + C = T.match_buffer(c, (16, 16), offset_factor=1) with T.block("root"): - vi = T.axis.S(16, 0) - vj = T.axis.S(16, 0) - vk = T.axis.R(1, 0) T.reads( - C[vi : vi + 16, vj : vj + 16], - A[vi : vi + 16, vk : vk + 1], - B[vj : vj + 16, vk : vk + 1], + C[0 : 16, 0 : 16], + A[0 : 16, 0 : 1], + B[0 : 16, 0 : 1], ) - T.writes(C[vi : vi + 16, vj : vj + 16]) + T.writes(C[0 : 16, 0 : 16]) for i, j in T.grid(16, 16): with T.block("update"): - vii = T.axis.S(16, vi + i) - vjj = T.axis.S(16, vj + j) - C[vii, vjj] = C[vii, vjj] + A[vii, vk] * B[vjj, vk] + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0] @T.prim_func def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 1)) - B = T.match_buffer(b, (16, 1)) - C = T.match_buffer(c, (16, 16)) + A = T.match_buffer(a, (16, 1), offset_factor=1) + B = T.match_buffer(b, (16, 1), offset_factor=1) + C = T.match_buffer(c, (16, 16), offset_factor=1) with T.block("root"): - vi = T.axis.S(16, 0) - vj = T.axis.S(16, 0) - vk = T.axis.R(1, 0) T.reads( - C[vi : vi + 16, vj : vj + 16], - A[vi : vi + 16, vk : vk + 1], - B[vj : vj + 16, vk : vk + 1], + C[0 : 16, 0 : 16], + A[0 : 16, 0 : 1], + B[0 : 16, 0 : 1], ) - T.writes(C[vi : vi + 16, vj : vj + 16]) + T.writes(C[0 : 16, 0 : 16]) T.evaluate( T.call_extern( "outer_product", @@ -371,15 +352,19 @@ def tensorized_batch_matmul_outer_product( B[vn, vjo * 16 : vjo * 16 + 16, vk], ) T.writes(C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) - A_1 = T.match_buffer(A[vn, vio * 16 : vio * 16 + 16, 0], [16, 1], dtype="float32") - B_1 = T.match_buffer(B[vn, vjo * 16 : vjo * 16 + 16, 0], [16, 1], dtype="float32") + A_1 = T.match_buffer(A[vn, vio * 16 : vio * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1) + B_1 = T.match_buffer(B[vn, vjo * 16 : vjo * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1 + ) C_1 = T.match_buffer( - C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], [16, 16], dtype="float32" + C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], [16, 16], dtype="float32", offset_factor=1 ) T.evaluate( - T.call_extern("outer_product", C_1.data, 0, A_1.data, 0, B_1.data, 0, dtype="int32") + T.call_extern("outer_product", C_1.data, C_1.elem_offset, A_1.data, A_1.elem_offset, + B_1.data, B_1.elem_offset, dtype="int32" + ) ) + # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks From 5fe38af7e2b4ad78d096e9b5294b2c764db35ae6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 20 Jan 2022 15:12:05 -0500 Subject: [PATCH 10/20] update --- python/tvm/tir/schedule/schedule.py | 91 ++++++++----------- .../schedule/primitive/blockize_tensorize.cc | 61 +++++++------ 2 files changed, 72 insertions(+), 80 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 0d432877ed7f..d3a5c1ee919a 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1822,14 +1822,13 @@ def after_blockize( vi, vj = T.axis.remap("SS", [i_1, j_1]) T.reads(A[vio * 16 + vi, vjo * 16 + vj]) T.writes(B[vio * 16 + vi, vjo * 16 + vj]) - B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * - T.float32(2) + B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] \ + * T.float32(2) Note ---- blockize requires there is exactly one block under the given loop and the bindings of the block are divisible by the subspace represented by the loops starting at the given loop. - """ return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member @@ -1863,19 +1862,22 @@ def before_tensorize( for i_0, j_0 in T.grid(8, 8): for i_1_init, j_1_init in T.grid(16, 16): with T.block("init"): - vi = T.axis.spatial(128, i_0 * 16 + i_1_init) - vj = T.axis.spatial(128, j_0 * 16 + j_1_init) + vi, vj = T.axis.remap("SS", [i_1_init, j_1_init]) T.reads() - T.writes(C[vi, vj]) - C[vi, vj] = T.float32(0) + T.writes(C[i_0 * 16 + vi, j_0 * 16 + vj]) + C[i_0 * 16 + vi, j_0 * 16 + vj] = T.float32(0) for k_0, i_1, j_1, k_1 in T.grid(8, 16, 16, 16): with T.block("update"): - vi = T.axis.spatial(128, i_0 * 16 + i_1) - vj = T.axis.spatial(128, j_0 * 16 + j_1) - vk = T.axis.reduce(128, k_0 * 16 + k_1) - T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) - T.writes(C[vi, vj]) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + vi, vj, vk = T.axis.remap("SSR", [i_1, j_1, k_1]) + T.reads( + C[i_0 * 16 + vi, j_0 * 16 + vj], + A[i_0 * 16 + vi, k_0 * 16 + vk], + B[j_0 * 16 + vj, k_0 * 16 + vk] + ) + T.writes(C[i_0 * 16 + vi, j_0 * 16 + vj]) + C[i_0 * 16 + vi, j_0 * 16 + vj] = C[i_0 * 16 + vi, j_0 * 16 + vj] + \ + A[i_0 * 16 + vi, k_0 * 16 + vk] * \ + B[j_0 * 16 + vj, k_0 * 16 + vk] Declare and register the tensor intrinsic: @@ -1888,15 +1890,12 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) with T.block("root"): - vi = T.axis.S(16, 0) - vj = T.axis.S(16, 0) - vk = T.axis.R(16, 0) + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) for i, j, k in T.grid(16, 16, 16): with T.block("update"): - vii = T.axis.S(16, vi + i) - vjj = T.axis.S(16, vj + j) - vkk = T.axis.R(16, vk + k) - C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -1906,17 +1905,8 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) with T.block("root"): - vi = T.axis.S(16, 0) - vj = T.axis.S(16, 0) - vk = T.axis.R(16, 0) - T.reads( - [ - C[vi : vi + 16, vj : vj + 16], - A[vi : vi + 16, vk : vk + 16], - B[vj : vj + 16, vk : vk + 16], - ] - ) - T.writes(C[vi : vi + 16, vj : vj + 16]) + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) T.evaluate( T.tvm_mma_sync( C.data, @@ -1938,9 +1928,9 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_tensorize) - update = s.get_block("update") - _, _, _, i1, _, _ = s.get_loops(update) - s.tensorize(ii, "test_mma_intrin") + update = sch.get_block("update") + _, _, _, i1, _, _ = sch.get_loops(update) + sch.tensorize(i1, "test_mma_intrin") print(sch.mod["main"].script()) After applying tensorize, the IR becomes: @@ -1948,44 +1938,43 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: .. code-block:: python @T.prim_func - def after_tensoirze( + def after_tensorize( A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"], ) -> None: - # body - # with T.block("root") for i_0, j_0 in T.grid(8, 8): for i_1_init, j_1_init in T.grid(16, 16): with T.block("init"): - vi = T.axis.spatial(128, i_0 * 16 + i_1_init) - vj = T.axis.spatial(128, j_0 * 16 + j_1_init) + vi, vj = T.axis.remap("SS", [i_1_init, j_1_init]) T.reads() - T.writes(C[vi, vj]) - C[vi, vj] = T.float32(0) + T.writes(C[i_0 * 16 + vi, j_0 * 16 + vj]) + C[i_0 * 16 + vi, j_0 * 16 + vj] = T.float32(0) for k_0 in T.serial(8): - with T.block("blockized_update"): - vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0]) + with T.block("update_o"): + vio = T.axis.spatial(1, 0) + vjo = T.axis.spatial(1, 0) + vko = T.axis.reduce(1, 0) T.reads( - C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], - A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], - B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + C[i_0 * 16 : i_0 * 16 + 16, j_0 * 16 : j_0 * 16 + 16], + A[i_0 * 16 : i_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], + B[j_0 * 16 : j_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], ) - T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(C[i_0 * 16 : i_0 * 16 + 16, j_0 * 16 : j_0 * 16 + 16]) A_1 = T.match_buffer( - A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + A[i_0 * 16 : i_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], [16, 16], dtype="float32", offset_factor=1, ) B_1 = T.match_buffer( - B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + B[j_0 * 16 : j_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], [16, 16], dtype="float32", offset_factor=1, ) C_1 = T.match_buffer( - C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + C[i_0 * 16 : i_0 * 16 + 16, j_0 * 16 : j_0 * 16 + 16], [16, 16], dtype="float32", offset_factor=1, @@ -2006,7 +1995,7 @@ def after_tensoirze( """ _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member - self, block_or_loop, tensor_intrin + self, block_or_loop, tensor_intrin ) ########## Schedule: Annotation ########## diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 59e9779bb290..59bf8b1b626f 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -152,8 +152,9 @@ Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_r std::vector init_loops; for (const ForNode* inner_loop : inner_loops) { for (const PrimExpr& init_binding : init_bindings) { - if (UsesVar(init_binding, - [tgt_var = inner_loop->loop_var.get()](const VarNode* var) { return var == tgt_var; })) { + if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const VarNode* var) { + return var == tgt_var; + })) { init_loops.push_back(inner_loop); break; } @@ -296,14 +297,17 @@ class BlockizedBindingExtractor { const IterVar& iter_var = iter_vars[i]; arith::IterMark outer_mark = division[i][0]; arith::IterMark inner_mark = division[i][1]; - const auto* outer_binding = TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode); - const auto* inner_binding = TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode); + const auto* outer_binding = + TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode); + const auto* inner_binding = + TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode); // After computing the subspace division, bindings[i] can be written as // outer_binding * inner_binding->extent + inner_binding // The outer block will have binding: iter_outer -> outer_binding - // The inner block will have binding: iter_inner -> iter_outer * inner_binding->extent + - // inner_binding + // The inner block will have binding: iter_inner -> inner_binding + // The iter in the original block will be substituted with base + iter_inner where + // base == iter_outer * iter_inner_extent if (is_one(division[i][1]->extent)) { // IsOuter // extract this iter var to outer block directly @@ -311,19 +315,19 @@ class BlockizedBindingExtractor { arith::NormalizeIterMapToExpr(GetRef(outer_binding))); outer_iter_vars.push_back(iter_var); } else { + // create iter var for the outer block const IterVar outer_var(Range::FromMinExtent(0, division[i][0]->extent), iter_var->var.copy_with_suffix("o"), iter_var->iter_type); outer_bindings.push_back( arith::NormalizeIterMapToExpr(GetRef(outer_binding))); outer_iter_vars.push_back(outer_var); PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; - inner_iter_relaxed_range.Set(iter_var->var, - arith::IntSet::FromMinExtent(base, division[i][1]->extent)); + // create iter var for the inner block IterVar new_iter = iter_var; auto* new_iter_node = new_iter.CopyOnWrite(); new_iter_node->dom = Range::FromMinExtent(0, division[i][1]->extent); + inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom)); analyzer->Bind(new_iter->var, new_iter->dom); - // new_iter_node->dom inner_iter_vars.push_back(new_iter); inner_bindings.push_back( arith::NormalizeIterMapToExpr(GetRef(inner_binding))); @@ -340,11 +344,8 @@ class BlockizedBindingExtractor { Array outer_bindings; /*! \brief Binding values of the inner block. */ Array inner_bindings; - - /*! \brief The range of the inner block iters Note that this is different from the domain of the - * inner block iters. - */ - Map inner_iter_relaxed_range; + /*! \brief The domain of the inner block iters. */ + Map inner_iter_dom_map; }; /*! @@ -399,12 +400,10 @@ class InnerIterReplacer : public StmtExprMutator { * \brief Compute the access region of the outer block by relaxing the inner loops. * \param buffer_region The original buffer region. * \param The range of the inner loops. - * \param analyzer The arithmetic analyzer. * \return The new buffer region. */ BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, - const Map& inner_iter_relaxed_range, - arith::Analyzer* analyzer) { + const Map& inner_iter_relaxed_range) { Array new_region; new_region.reserve(buffer_region->region.size()); Array relaxed_int_set = @@ -424,13 +423,12 @@ BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, * \param inner_block_realize The block realize of the inner block after blockize. * \param inner_loops The inner loops after blockize. * \param predicate The outer predicate of the subspace division. - * \param analyzer The arithmetic analyzer. * \return The block realize of the outer block after blockize. */ BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor, const Block& block, BlockRealize inner_block_realize, const std::vector& inner_loops, - PrimExpr predicate, arith::Analyzer* analyzer) { + PrimExpr predicate) { // Step 1: Generate the init block if needed Optional new_init = NullOpt; if (block->init.defined()) { @@ -442,7 +440,7 @@ BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extrac Array new_writes = block->writes; auto f_mutate = [&](const BufferRegion& buffer_region) { - return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_relaxed_range, analyzer); + return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_dom_map); }; new_reads.MutateByApply(f_mutate); new_writes.MutateByApply(f_mutate); @@ -492,22 +490,27 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { const PrimExpr& outer_pred = division.back()[0]->extent; const PrimExpr& inner_pred = division.back()[1]->extent; - // Step 5: Generate the inner block. + // Step 5: Substitute the iter vars in the original block with the inner iters after the subspace + // division Map block_sref_reuse; + InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer, + &block_sref_reuse); + Block new_block = Downcast(replacer(block)); + + // Step 6: Generate the inner block. BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite(); - BlockNode* inner_block = inner_block_realize->block.CopyOnWrite(); inner_block_realize->iter_values = extractor.inner_bindings; inner_block_realize->predicate = inner_pred; + inner_block_realize->block = new_block; + BlockNode* inner_block = inner_block_realize->block.CopyOnWrite(); inner_block->iter_vars = extractor.inner_iter_vars; inner_block->init = NullOpt; - InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer, - &block_sref_reuse); - inner_block_realize->block = Downcast(replacer(inner_block_realize->block)); + block_sref_reuse.Set(block, inner_block_realize->block); // Step 6: Generate the outer block. BlockRealize outer_realize = - GenerateBlockizedOuterBlock(extractor, block, GetRef(inner_block_realize), - collector.inner_loops, outer_pred, &analyzer); + GenerateBlockizedOuterBlock(extractor, new_block, GetRef(inner_block_realize), + collector.inner_loops, outer_pred); // Step 7: Do the actual replacement self->Replace(loop_sref, outer_realize, block_sref_reuse); @@ -589,10 +592,10 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor // intrin to make them subregions of the buffer in the original IR. std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_region_map; - for (const auto& read : impl_block->reads) { + for (const BufferRegion& read : impl_block->reads) { buffer_region_map.emplace(read->buffer, read->region); } - for (const auto& write : impl_block->writes) { + for (const BufferRegion& write : impl_block->writes) { buffer_region_map.emplace(write->buffer, write->region); } Array match_buffer_regions; From f03f88aff4436c2c4375fff8c4c7dc7a1f4ce7a5 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 20 Jan 2022 15:29:07 -0500 Subject: [PATCH 11/20] nit --- src/tir/schedule/primitive/blockize_tensorize.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 59bf8b1b626f..70fbd6e2038d 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -316,8 +316,9 @@ class BlockizedBindingExtractor { outer_iter_vars.push_back(iter_var); } else { // create iter var for the outer block - const IterVar outer_var(Range::FromMinExtent(0, division[i][0]->extent), - iter_var->var.copy_with_suffix("o"), iter_var->iter_type); + const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, division[i][0]->extent), + /*var=*/iter_var->var.copy_with_suffix("_o"), + /*iter_type=*/iter_var->iter_type); outer_bindings.push_back( arith::NormalizeIterMapToExpr(GetRef(outer_binding))); outer_iter_vars.push_back(outer_var); From e54add3d679da82d825531b874e5906f87d9da66 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 20 Jan 2022 17:58:12 -0500 Subject: [PATCH 12/20] fix example --- python/tvm/tir/schedule/schedule.py | 126 +++++++++++++--------------- src/tir/ir/function.cc | 13 +-- src/tir/schedule/ir_comparator.h | 5 ++ src/tir/schedule/primitive.h | 3 +- 4 files changed, 73 insertions(+), 74 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d3a5c1ee919a..297bdb8e5143 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1859,25 +1859,16 @@ def before_tensorize( ) -> None: # body # with T.block("root") - for i_0, j_0 in T.grid(8, 8): - for i_1_init, j_1_init in T.grid(16, 16): - with T.block("init"): - vi, vj = T.axis.remap("SS", [i_1_init, j_1_init]) - T.reads() - T.writes(C[i_0 * 16 + vi, j_0 * 16 + vj]) - C[i_0 * 16 + vi, j_0 * 16 + vj] = T.float32(0) - for k_0, i_1, j_1, k_1 in T.grid(8, 16, 16, 16): - with T.block("update"): - vi, vj, vk = T.axis.remap("SSR", [i_1, j_1, k_1]) - T.reads( - C[i_0 * 16 + vi, j_0 * 16 + vj], - A[i_0 * 16 + vi, k_0 * 16 + vk], - B[j_0 * 16 + vj, k_0 * 16 + vk] - ) - T.writes(C[i_0 * 16 + vi, j_0 * 16 + vj]) - C[i_0 * 16 + vi, j_0 * 16 + vj] = C[i_0 * 16 + vi, j_0 * 16 + vj] + \ - A[i_0 * 16 + vi, k_0 * 16 + vk] * \ - B[j_0 * 16 + vj, k_0 * 16 + vk] + for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16): + with T.block("update"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + vk = T.axis.reduce(128, k_0 * 16 + k_1) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(C[vi, vj]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] Declare and register the tensor intrinsic: @@ -1943,56 +1934,55 @@ def after_tensorize( B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"], ) -> None: - for i_0, j_0 in T.grid(8, 8): - for i_1_init, j_1_init in T.grid(16, 16): - with T.block("init"): - vi, vj = T.axis.remap("SS", [i_1_init, j_1_init]) - T.reads() - T.writes(C[i_0 * 16 + vi, j_0 * 16 + vj]) - C[i_0 * 16 + vi, j_0 * 16 + vj] = T.float32(0) - for k_0 in T.serial(8): - with T.block("update_o"): - vio = T.axis.spatial(1, 0) - vjo = T.axis.spatial(1, 0) - vko = T.axis.reduce(1, 0) - T.reads( - C[i_0 * 16 : i_0 * 16 + 16, j_0 * 16 : j_0 * 16 + 16], - A[i_0 * 16 : i_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], - B[j_0 * 16 : j_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], - ) - T.writes(C[i_0 * 16 : i_0 * 16 + 16, j_0 * 16 : j_0 * 16 + 16]) - A_1 = T.match_buffer( - A[i_0 * 16 : i_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], - [16, 16], - dtype="float32", - offset_factor=1, - ) - B_1 = T.match_buffer( - B[j_0 * 16 : j_0 * 16 + 16, k_0 * 16 : k_0 * 16 + 16], - [16, 16], - dtype="float32", - offset_factor=1, - ) - C_1 = T.match_buffer( - C[i_0 * 16 : i_0 * 16 + 16, j_0 * 16 : j_0 * 16 + 16], - [16, 16], - dtype="float32", - offset_factor=1, - ) - T.evaluate( - T.tvm_mma_sync( - C_1.data, - C_1.elem_offset // 256, - A_1.data, - A_1.elem_offset // 256, - B_1.data, - B_1.elem_offset // 256, - C_1.data, - C_1.elem_offset // 256, - dtype="handle", - ) + # body + # with T.block("root") + for i_0, j_0, k_0 in T.grid(8, 8, 8): + with T.block("update_o"): + vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0]) + T.reads( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + ) + T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + A_1 = T.match_buffer( + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + B_1 = T.match_buffer( + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + C_1 = T.match_buffer( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + with T.init(): + for i_1, j_1 in T.grid(16, 16): + with T.block("update_init"): + vi_init, vj_init = T.axis.remap("SS", [i_1, j_1]) + T.reads() + T.writes(C[vio * 16 + vi_init, vjo * 16 + vj_init]) + C[vio * 16 + vi_init, vjo * 16 + vj_init] = T.float32(0) + T.evaluate( + T.tvm_mma_sync( + C_1.data, + C_1.elem_offset // 256, + A_1.data, + A_1.elem_offset // 256, + B_1.data, + B_1.elem_offset // 256, + C_1.data, + C_1.elem_offset // 256, + dtype="handle", ) - + ) """ _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member self, block_or_loop, tensor_intrin diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 3b35cd45b8ba..1c34e34468b5 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -76,10 +76,14 @@ class TensorIntrinManager { TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { // Check the number of func var is equal - CHECK_EQ(desc->params.size(), impl->params.size()) << "ValueError: The number of parameters of the description and the implementation of the tensor intrinsic doesn't match."; + CHECK_EQ(desc->params.size(), impl->params.size()) + << "ValueError: The number of parameters of the description and the implementation of the " + "tensor intrinsic doesn't match."; for (size_t i = 0; i < desc->params.size(); i++) { - CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the tensor intrinsic should be handle only."; - CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of the tensor intrinsic should be handle only."; + CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the " + "tensor intrinsic should be handle only."; + CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of " + "the tensor intrinsic should be handle only."; } ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); @@ -99,8 +103,7 @@ void TensorIntrin::Register(String name, TensorIntrin intrin) { TensorIntrin TensorIntrin::Get(String name) { const TensorIntrinManager* manager = TensorIntrinManager::Global(); auto it = manager->reg.find(name); - CHECK(it != manager->reg.end()) - << "ValueError: TensorIntrin '" << name << "' is not registered"; + CHECK(it != manager->reg.end()) << "ValueError: TensorIntrin '" << name << "' is not registered"; return manager->reg.at(name); } diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index b81cd6a4a6f1..156104965aef 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -19,6 +19,11 @@ #ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ #define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ +#include +#include +#include +#include + #include "./utils.h" namespace tvm { diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 68c74413a219..2368411e6f09 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -393,7 +393,8 @@ TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); * \param block_or_loop_sref The block or loop to be tensorized. * \param intrin The tensor intrinsic. */ -TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, const TensorIntrin& intrin); +TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, + const TensorIntrin& intrin); /******** Schedule: Annotation ********/ /*! From 2902b72db8b2ac596be71de0d149e49f1a8776de Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 20 Jan 2022 18:09:39 -0500 Subject: [PATCH 13/20] lint --- python/tvm/tir/schedule/schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 297bdb8e5143..96fa21f30020 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -21,7 +21,7 @@ from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, TensorIntrin +from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod From 4fb00a5ac6f0eafeed3cf2615c691398f228c0d0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 20 Jan 2022 18:23:12 -0500 Subject: [PATCH 14/20] lint --- include/tvm/tir/function.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index a85d19a1be76..1ab911b756df 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -216,7 +216,7 @@ class TensorIntrin : public ObjectRef { * \param desc The function to describe the computation. * \param impl The function of the implementation for the execution. */ - TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func); + TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl); /*! * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked From 36d3a5e447c6e177fe586827ce0badf5805ee0bc Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 20 Jan 2022 19:45:26 -0500 Subject: [PATCH 15/20] lint --- src/tir/schedule/primitive/blockize_tensorize.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 70fbd6e2038d..1fa77f7f3d8b 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -564,7 +564,7 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, // Step 1: Blockize the subtree rooted at the given loop if needed StmtSRef block_sref{nullptr}; - if (const auto* loop = block_or_loop_sref->StmtAs()) { + if (block_or_loop_sref->StmtAs()) { block_sref = Blockize(self, block_or_loop_sref); } else { ICHECK(block_or_loop_sref->StmtAs()); @@ -608,13 +608,14 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, // add the detected base indices to each buffer access region of the tensor intrinsic Region old_region = buffer_region_map.at(buffer); const auto& indices_base = comparator.buffer_indices_.at(source); - int offset = indices_base.size() - old_region.size(); + int offset = static_cast(indices_base.size()) - static_cast(old_region.size()); + ICHECK(offset >= 0); Region new_region; new_region.reserve(source->shape.size()); for (int i = 0; i < offset; i++) { new_region.push_back(Range::FromMinExtent(indices_base[i], 1)); } - for (int i = 0; i < old_region.size(); i++) { + for (int i = 0; i < static_cast(old_region.size()); i++) { new_region.push_back(Range::FromMinExtent(indices_base[i + offset], old_region[i]->extent)); } match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, new_region))); From c11babbcbba46ec22d00a7aff4cd979a08468355 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 24 Jan 2022 15:35:42 -0500 Subject: [PATCH 16/20] remove unused --- src/tir/schedule/ir_comparator.cc | 15 +++++++++- src/tir/schedule/ir_comparator.h | 1 + src/tir/schedule/transform.cc | 47 ------------------------------- src/tir/schedule/transform.h | 8 ------ 4 files changed, 15 insertions(+), 56 deletions(-) diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 2d11411f6735..2ec1264a213d 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -196,6 +196,12 @@ bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& o return CompareBufferAccess(op, rhs); } +bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value, rhs->true_value) + && VisitExpr(op->false_value, rhs->false_value); +} + bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { if (lhs.same_as(rhs)) return true; auto it = equal_map_.find(lhs); @@ -253,7 +259,14 @@ bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { } bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { - if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + if (!CompareBuffer(lhs->buffer, rhs->buffer)) { + if (assert_mode_) { + std::ostringstream os; + os << "Buffer mismatch: " << lhs->buffer << " vs " << rhs->buffer; + EmitError(os.str()); + } + return false; + } int offset = static_cast(lhs->region.size()) - static_cast(rhs->region.size()); // Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS if (offset < 0) return false; diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index 156104965aef..359677d8852f 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -74,6 +74,7 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SelectNode* op, const PrimExpr& other) override; /*! \brief Map from RHS buffer to LHS buffer */ std::unordered_map rhs_buffer_map_; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 0165df5949df..ffb6b2d52628 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -136,52 +136,5 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } -/******** IR Substitution ********/ -class IRSubstituteInScope : public StmtExprMutator { - public: - explicit IRSubstituteInScope(std::function fmap) - : fmap_(std::move(fmap)) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = fmap_(op); - if (it.defined()) { - return it; - } else { - return GetRef(op); - } - } - - Stmt VisitStmt_(const BlockRealizeNode* op) final { - auto fmutate = [&](const PrimExpr& e) { return this->VisitExpr(e); }; - Array v = op->iter_values; - v.MutateByApply(fmutate); - PrimExpr pred = this->VisitExpr(op->predicate); - if (v.same_as(op->iter_values) && pred.same_as(op->predicate)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->iter_values = std::move(v); - n->predicate = std::move(analyzer.Simplify(pred)); - return Stmt(n); - } - } - - private: - const std::function fmap_; - arith::Analyzer analyzer; -}; - -Stmt SubstituteInScope(const Stmt& stmt, const Map& subst_map) { - auto fmap = [&](const VarNode* v) -> PrimExpr { - const auto& it = subst_map.find(GetRef(v)); - if (it != subst_map.end()) { - return (*it).second; - } else { - return NullValue(); - } - }; - return IRSubstituteInScope(std::move(fmap))(stmt); -} - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index b1ce52407baf..3932c4bdbd3d 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -104,14 +104,6 @@ Array ReplaceBuffer(Array match_buffers, c void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt); -/******** IR Substitution ********/ - -/*! - * \param var_map The mapping of var - * \return The converted stmt - */ -Stmt SubstituteInScope(const Stmt& stmt, const Map& subst_map); - } // namespace tir } // namespace tvm From b574f456ec5ef2aca5c5a0d6efe2b99089bb999b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 24 Jan 2022 20:46:26 -0500 Subject: [PATCH 17/20] trigger ci From 0e1e7cdffe16111d8c349e306e7ed32697c512c2 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 24 Jan 2022 22:10:19 -0800 Subject: [PATCH 18/20] clang-format --- src/tir/schedule/ir_comparator.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 2ec1264a213d..3e61e953a95b 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -198,8 +198,8 @@ bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& o bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value, rhs->true_value) - && VisitExpr(op->false_value, rhs->false_value); + return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value, rhs->true_value) && + VisitExpr(op->false_value, rhs->false_value); } bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { From 6beb550e146c41ba0533ae2a3540bf091898d335 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 25 Jan 2022 13:33:31 -0500 Subject: [PATCH 19/20] fix --- src/tir/schedule/primitive/blockize_tensorize.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 1fa77f7f3d8b..bbeb9caaab9b 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -96,10 +96,8 @@ Array> TrivialSubspaceDivision(const Array& iter arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); for (size_t i = 0; i < bindings.size(); ++i) { - bool outer = UsesVar( - bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); }); - bool inner = UsesVar( - bindings[i], [&inner_loop_vars](const VarNode* var) { return inner_loop_vars.count(var); }); + bool outer = use_outer_loop_vars(bindings[i]); + bool inner = use_inner_loop_vars(bindings[i]); arith::IterMark iter_mark; if (bindings[i]->IsInstance()) { iter_mark = arith::IterMark( @@ -165,10 +163,10 @@ Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_r Map subst_map; for (size_t i = 0; i < init_block_iters.size(); i++) { IterVar new_iter_var = init_block_iters[i]; - auto* new_init_var_node = new_iter_var.CopyOnWrite(); Var old_var = new_iter_var->var; - new_init_var_node->var = old_var.copy_with_suffix("_init"); - subst_map.Set(old_var, new_iter_var->var); + Var new_var = old_var.copy_with_suffix("_init"); + new_iter_var.CopyOnWrite()->var = new_var; + subst_map.Set(old_var, new_var); init_block_iters.Set(i, std::move(new_iter_var)); } From dc2854b526751ede0304df7f284cd83432686a47 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 25 Jan 2022 20:09:16 -0500 Subject: [PATCH 20/20] rebase --- src/tir/schedule/analysis/analysis.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index ddb4d3dd0a1c..be5e55d4ec70 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1848,7 +1848,5 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } } -======= ->>>>>>> cde8c476f (WIP) } // namespace tir } // namespace tvm