diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 22b4cd580e18..eb69c188abf3 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -350,6 +350,13 @@ Array> SubspaceDivide(const Array& bindings, bool require_bijective, arith::Analyzer* analyzer, DiagnosticContext diag_ctx); +/*! + * \brief Given an IterMapExpr, transform it to normal PrimExpr. + * \param expr The input IterMapExpr. + * \return The corresponding normal PrimExpr. + */ +PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_ITER_AFFINE_MAP_H_ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index e482a18c4a5b..1ab911b756df 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,58 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! + * \brief Tensor intrinsics for tensorization + */ +class TensorIntrinNode : public Object { + public: + /*! \brief The function to describe the computation. */ + PrimFunc desc; + /*! \brief The function of the implementation for the execution. */ + PrimFunc impl; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("desc", &desc); + v->Visit("impl", &impl); + } + + static constexpr const char* _type_key = "tir.TensorIntrin"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); +}; + +/*! + * \brief Managed reference to TensorIntrinNode. + */ +class TensorIntrin : public ObjectRef { + public: + /*! + * \brief Constructor + * \param desc The function to describe the computation. + * \param impl The function of the implementation for the execution. + */ + TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl); + + /*! + * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked + * up with its name. + * \param name The name of the TensorIntrin to register + * \param intrin The TensorIntrin to register. + * \throws This method throws an exception if the TensorIntrin with the specified name already + * exists. + */ + TVM_DLL static void Register(String name, TensorIntrin intrin); + + /*! + * \brief Look up TensorIntrin by name. Raises an exception if not found. + * \param name The name of the TensorIntrin. + * \return The TensorIntrin with the specified name. + * \throws This method throws an exception if the TensorIntrin does not exist. + */ + TVM_DLL static TensorIntrin Get(String name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) +}; + /*! * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 43f2379a0b56..be06b44820cd 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -473,6 +473,25 @@ class ScheduleNode : public runtime::Object { */ virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; /******** Schedule: Blockize & Tensorize ********/ + /*! + * \brief Convert the subtree rooted at a specific loop into a block. + * \param loop_rv the root of the subtree + * \return the new block + */ + virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with the tensor intrin. 
+ * \param loop_rv The loop to be tensorized + * \param intrin Name of the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with the tensor intrin. + * \param block_rv The block to be tensorized + * \param intrin Name of the tensor intrinsic + */ + virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0; + /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 07ceb29ebf98..5854b9369c16 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize -from .function import PrimFunc +from .function import PrimFunc, TensorIntrin from .op import call_packed, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index ecbcd837cb72..bcebab9ddc0a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -162,3 +162,51 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: return tvm._ffi.get_global_func("script.AsTVMScript")( self, tir_prefix, show_meta ) # type: ignore + + +@tvm._ffi.register_object("tir.TensorIntrin") +class TensorIntrin(Object): + """A tensor intrinsic. + + Parameters + ---------- + desc : PrimFunc + The function to describe the computation. + + impl : PrimFunc + The function of the implementation for the execution. + """ + + def __init__(self, desc, impl): + self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl) + + @staticmethod + def register(name: str, desc: PrimFunc, impl: PrimFunc): + """Register a tensor intrinsic with its name. + + Parameters + ---------- + name : str + The name of the TensorIntrin to register. + desc : PrimFunc + The function to describe the computation. + impl : PrimFunc + The function of the implementation for the execution. + """ + return _ffi_api.TensorIntrinRegister(name, TensorIntrin(desc, impl)) # type: ignore + + @staticmethod + def get(name: str): + """Look up a tensor intrinsic by its name. + + Parameters + ---------- + name : str + The name of the TensorIntrin to look up. + + Returns + ------- + result : TensorIntrin + The TensorIntrin with the specified name. + """ + return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7d352f156a31..96fa21f30020 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1759,6 +1759,235 @@ def after_set_scope( ########## Schedule: Blockize & Tensorize ########## + @type_checked + def blockize(self, loop: LoopRV) -> BlockRV: + """Convert the subtree rooted at a specific loop into a block. + + Parameters + ---------- + loop : LoopRV + The root of the subtree. + + Returns + ------- + result : BlockRV + The new block. + + Examples + -------- + + Before blockize, in TensorIR, the IR is: + + .. 
code-block:: python

+            @T.prim_func
+            def before_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
+                    with T.block("B"):
+                        vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                        T.reads(A[vi, vj])
+                        T.writes(B[vi, vj])
+                        B[vi, vj] = A[vi, vj] * T.float32(2)
+
+        Create the schedule and do blockize:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_blockize)
+            B = sch.get_block("B")
+            _, _, i1, _ = sch.get_loops(B)
+            sch.blockize(i1)
+            print(sch.mod["main"].script())
+
+        After applying blockize, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i_0, j_0 in T.grid(8, 8):
+                    with T.block("B_o"):
+                        vio, vjo = T.axis.remap("SS", [i_0, j_0])
+                        T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+                        T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+                        for i_1, j_1 in T.grid(16, 16):
+                            with T.block("B"):
+                                vi, vj = T.axis.remap("SS", [i_1, j_1])
+                                T.reads(A[vio * 16 + vi, vjo * 16 + vj])
+                                T.writes(B[vio * 16 + vi, vjo * 16 + vj])
+                                B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] \
+                                    * T.float32(2)
+
+        Note
+        ----
+        blockize requires that there is exactly one block under the given loop and that the
+        bindings of the block are divisible by the subspace represented by the loops starting
+        at the given loop.
+        """
+
+        return _ffi_api.ScheduleBlockize(self, loop)  # type: ignore # pylint: disable=no-member
+
+    @type_checked
+    def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None:
+        """Tensorize the computation enclosed by the given loop or block with a tensor intrinsic.
+
+        Parameters
+        ----------
+        block_or_loop : Union[BlockRV, LoopRV]
+            The block or loop to be tensorized.
+        tensor_intrin : str
+            The name of the registered tensor intrinsic to apply.
+
+        Examples
+        --------
+
+        Before tensorize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_tensorize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"],
+                C: T.Buffer[(128, 128), "float32"],
+            ) -> None:
+                # body
+                # with T.block("root")
+                for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16):
+                    with T.block("update"):
+                        vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                        vk = T.axis.reduce(128, k_0 * 16 + k_1)
+                        T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+                        T.writes(C[vi, vj])
+                        with T.init():
+                            C[vi, vj] = T.float32(0)
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Declare and register the tensor intrinsic:
+
+        .. 
code-block:: python + + @T.prim_func + def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + + @T.prim_func + def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) + + Create the schedule and do tensorize: + + .. code-block:: python + + sch = tir.Schedule(before_tensorize) + update = sch.get_block("update") + _, _, _, i1, _, _ = sch.get_loops(update) + sch.tensorize(i1, "test_mma_intrin") + print(sch.mod["main"].script()) + + After applying tensorize, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_tensorize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + # body + # with T.block("root") + for i_0, j_0, k_0 in T.grid(8, 8, 8): + with T.block("update_o"): + vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0]) + T.reads( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + ) + T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + A_1 = T.match_buffer( + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + B_1 = T.match_buffer( + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + C_1 = T.match_buffer( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + with T.init(): + for i_1, j_1 in T.grid(16, 16): + with T.block("update_init"): + vi_init, vj_init = T.axis.remap("SS", [i_1, j_1]) + T.reads() + T.writes(C[vio * 16 + vi_init, vjo * 16 + vj_init]) + C[vio * 16 + vi_init, vjo * 16 + vj_init] = T.float32(0) + T.evaluate( + T.tvm_mma_sync( + C_1.data, + C_1.elem_offset // 256, + A_1.data, + A_1.elem_offset // 256, + B_1.data, + B_1.elem_offset // 256, + C_1.data, + C_1.elem_offset // 256, + dtype="handle", + ) + ) + """ + _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member + self, block_or_loop, tensor_intrin + ) + ########## Schedule: Annotation ########## @type_checked diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 55a1a5a1830e..3d30eef99d7d 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -511,7 +511,7 @@ Range IntSet::CoverRange(Range max_range) const { const IntervalSetNode* s_int = (*this).as(); ICHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { - return 
Range::FromMinExtent(s_int->min_value, + return Range::FromMinExtent(analyzer.Simplify(s_int->min_value), analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 101d80a52ea1..1c34e34468b5 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -64,6 +64,51 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); +class TensorIntrinManager { + public: + Map reg; + + static TensorIntrinManager* Global() { + static TensorIntrinManager* inst = new TensorIntrinManager(); + return inst; + } +}; + +TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { + // Check the number of func var is equal + CHECK_EQ(desc->params.size(), impl->params.size()) + << "ValueError: The number of parameters of the description and the implementation of the " + "tensor intrinsic doesn't match."; + for (size_t i = 0; i < desc->params.size(); i++) { + CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the " + "tensor intrinsic should be handle only."; + CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of " + "the tensor intrinsic should be handle only."; + } + ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); + + ObjectPtr n = make_object(); + n->desc = std::move(desc); + n->impl = std::move(impl); + data_ = std::move(n); +} + +void TensorIntrin::Register(String name, TensorIntrin intrin) { + TensorIntrinManager* manager = TensorIntrinManager::Global(); + CHECK_EQ(manager->reg.count(name), 0) + << "ValueError: TensorIntrin '" << name << "' has already been registered"; + manager->reg.Set(name, intrin); +} + +TensorIntrin TensorIntrin::Get(String name) { + const TensorIntrinManager* manager = TensorIntrinManager::Global(); + auto it = manager->reg.find(name); + CHECK(it != manager->reg.end()) << "ValueError: TensorIntrin '" << name << "' is not registered"; + return manager->reg.at(name); +} + +TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { // TODO(tvm-team) redirect to Text printer once we have a good text format. 
@@ -85,5 +130,13 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
       return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
     });
 
+TVM_REGISTER_GLOBAL("tir.TensorIntrin")
+    .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) {
+      return TensorIntrin(desc_func, intrin_func);
+    });
+
+TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register);
+TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 9f8dc6dd2daf..fc63f305ff5e 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -606,6 +606,29 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
 }
 
 /******** Schedule: Blockize & Tensorize ********/
+BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::Blockize(state_, this->GetSRef(loop_rv));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
+  return CreateRV<BlockRV>(result);
+}
+
+void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
+}
+
+void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
+}
+
 /******** Schedule: Annotation ********/
 
 ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 96cb0f728835..5f108178a83b 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -123,6 +123,9 @@ class ConcreteScheduleNode : public ScheduleNode {
                             int offset) override;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
   /******** Schedule: Blockize & Tensorize ********/
+  BlockRV Blockize(const LoopRV& loop_rv) override;
+  void Tensorize(const BlockRV& block_rv, const String& intrin) override;
+  void Tensorize(const LoopRV& loop_rv, const String& intrin) override;
   /******** Schedule: Annotation ********/
   void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
   void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
new file mode 100644
index 000000000000..3e61e953a95b
--- /dev/null
+++ b/src/tir/schedule/ir_comparator.cc
@@ -0,0 +1,363 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./ir_comparator.h" + +namespace tvm { + +namespace tir { + +/******** Tensorize Comparator ********/ + +class TensorIntrinMismatchError : public ScheduleError { + public: + explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt, + std::vector error_messages) + : lhs_mod_(std::move(lhs_mod)), + lhs_stmt_(std::move(lhs_stmt)), + rhs_stmt_(std::move(rhs_stmt)), + error_messages_(std::move(error_messages)) { + ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); + } + + String FastErrorString() const final { + return "ScheduleError: The stmt doesn't match the tensor intrin."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The stmt {0} doesn't match the tensor intrin\n " << rhs_stmt_; + for (const auto& msg : error_messages_) { + os << msg << std::endl; + } + return os.str(); + } + + IRModule mod() const final { return lhs_mod_; } + + Array LocationsOfInterest() const final { return {lhs_stmt_}; } + + private: + IRModule lhs_mod_; + Stmt lhs_stmt_; + Stmt rhs_stmt_; + std::vector error_messages_; +}; + +/* Override the dispatcher to make sure RHS is always valid */ +bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { + bool equal = n.same_as(other) || + ((n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other)); + if (!equal && assert_mode_ && (n->IsInstance() || n->IsInstance())) { + throw TensorIntrinMismatchError(lhs_mod_, n, other, std::move(error_messages_)); + } + return equal; +} + +bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { + bool equal = + n.same_as(other) || ((n->type_index() == other->type_index()) && n->dtype == other->dtype && + ExprComparator::VisitExpr(n, other)); + if (!equal && assert_mode_) { + std::ostringstream os; + os << "Expression mismatch: " << n << " vs " << other; + EmitError(os.str()); + } + return equal; +} + +bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + if (!VisitExpr(op->min, rhs->min)) return false; + if (!VisitExpr(op->extent, rhs->extent)) return false; + if (op->thread_binding.defined() != rhs->thread_binding.defined()) return false; + if (op->thread_binding.defined() && + !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) { + return false; + } + if (op->kind != rhs->kind) return false; + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) return false; + return VisitStmt(op->body, rhs->body); +} + +bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); +} + +bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!is_scope_block) { + 
if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) { + return false; + } + } + return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); +} + +bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Check block equality. + // All iter vars and buffer regions including the order should match. + // When checking iter vars, DefEqual is used to remap variables. + if (!is_scope_block) { + if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) { + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + return false; + } + if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + return false; + } + } + if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + is_scope_block = false; + return VisitStmt(op->body, rhs->body); +} + +// Exprs +#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ + bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ + } + +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); + +bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + auto lhs = GetRef(op); + if (lhs.same_as(other)) return true; + if (op->dtype != rhs->dtype) return false; + auto it = equal_map_.find(lhs); + return it != equal_map_.end() && it->second.same_as(other); +} + +bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); +} + +bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value, rhs->true_value) && + VisitExpr(op->false_value, rhs->false_value); +} + +bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { + if 
(lhs.same_as(rhs)) return true; + auto it = equal_map_.find(lhs); + // If there is already a mapping + if (it != equal_map_.end()) return it->second.same_as(rhs); + // Otherwise remap lhs to rhs + equal_map_[lhs] = rhs; + analyzer_.Bind(lhs, rhs); + return true; +} + +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first != rhs.first) return false; + if (!lhs.second.same_as(rhs.second)) return false; + return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); +} + +bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, + const Map& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + + auto sort_map = + [](const Map& map) -> std::vector> { + std::vector> ret(map.begin(), map.end()); + sort(ret.begin(), ret.end()); + return ret; + }; + + std::vector> lhs_array = sort_map(lhs); + std::vector> rhs_array = sort_map(rhs); + + for (size_t i = 0; i < lhs.size(); ++i) { + if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + auto it = rhs_buffer_map_.find(rhs); + bool equal; + if (it != rhs_buffer_map_.end()) { + equal = (*it).second.same_as(lhs); + } else { + // Remap both buffer itself and buffer data, skip buffer shape + equal = + DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } + } + return equal; +} + +bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) { + if (assert_mode_) { + std::ostringstream os; + os << "Buffer mismatch: " << lhs->buffer << " vs " << rhs->buffer; + EmitError(os.str()); + } + return false; + } + int offset = static_cast(lhs->region.size()) - static_cast(rhs->region.size()); + // Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS + if (offset < 0) return false; + + auto it = buffer_indices_.find(lhs->buffer); + if (it == buffer_indices_.end()) { + // Update base indices for the buffer, this can only happen if it is visiting the scope block. + ICHECK(is_scope_block); + std::vector indices_base; + indices_base.reserve(lhs->region.size()); + for (int i = 0; i < offset; i++) { + // High-dim region must be element-wise + if (!is_one(lhs->region[i]->extent)) return false; + indices_base.emplace_back(lhs->region[i]->min); + } + for (size_t i = 0; i < rhs->region.size(); i++) { + // save base index + indices_base.emplace_back(lhs->region[i + offset]->min); + // check extent match + if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + return false; + } + } + buffer_indices_.emplace(lhs->buffer, std::move(indices_base)); + } else { + // Check the base indices are consistent. 
+ const std::vector& indices_base = it->second; + for (int i = 0; i < offset; i++) { + // High-dim region must be element-wise + if (!is_one(lhs->region[i]->extent)) return false; + if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) return false; + } + for (size_t i = 0; i < rhs->region.size(); i++) { + // check extent match + if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + return false; + } + PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]); + if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { + return false; + } + } + } + return true; +} + +// Comparator for BufferStoreNode and BufferLoadNode +template +bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + int offset = static_cast(lhs->indices.size()) - static_cast(rhs->indices.size()); + if (offset < 0) return false; + auto it = buffer_indices_.find(lhs->buffer); + ICHECK(it != buffer_indices_.end()); + const std::vector& indices_base = (*it).second; + ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset); + for (size_t i = 0; i < rhs->indices.size(); i++) { + PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i + offset]; + if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) { + if (assert_mode_) { + std::ostringstream os; + os << "Buffer indices mismatch: " << lhs->indices[i + offset] << " vs " << rhs->indices[i]; + EmitError(os.str()); + } + return false; + } + } + return true; +} + +template +bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(this->*cmp)(lhs[i], rhs[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { + return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); +} + +bool TensorizeComparator::CompareIterVar(const IterVar& lhs, const IterVar& rhs) { + return DefEqual(lhs->var, rhs->var) && lhs->iter_type == rhs->iter_type; +} + +void TensorizeComparator::EmitError(const std::string& error_message) { + error_messages_.push_back(error_message); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h new file mode 100644 index 000000000000..359677d8852f --- /dev/null +++ b/src/tir/schedule/ir_comparator.h @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ +#define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ + +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace tir { + +using ExprComparator = ExprFunctor; +using StmtComparator = StmtFunctor; + +/*! \brief Deep comparison to check if two IR ASTs are equivalent for tensorization*/ +class TensorizeComparator : public ExprComparator, public StmtComparator { + public: + /*! + * \brief Constructor of TensorizeComparator + * \param assert_mode Whether to raise an error if the two IR ASTs do not match. + * \param lhs_mod The IRModule of the LHS. This is used for error reporting. + */ + explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true) + : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) {} + + bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; + bool VisitStmt(const Stmt& n, const Stmt& other) override; + + bool VisitStmt_(const ForNode* op, const Stmt& other) override; + bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + + bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; + bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const ModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; + bool VisitExpr_(const NENode* op, const PrimExpr& other) override; + bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const LENode* op, const PrimExpr& other) override; + bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const GENode* op, const PrimExpr& other) override; + bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; + bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; + bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SelectNode* op, const PrimExpr& other) override; + + /*! \brief Map from RHS buffer to LHS buffer */ + std::unordered_map rhs_buffer_map_; + /*! \brief Base indices of the LHS buffer. 
*/ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; + + protected: + bool DefEqual(const Var& lhs, const Var& rhs); + virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const Map& lhs, const Map& rhs); + template + bool CompareBufferAccess(const T* lhs, const T* rhs); + template + bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + bool CompareRange(const Range& lhs, const Range& rhs); + bool CompareIterVar(const IterVar& lhs, const IterVar& rhs); + void EmitError(const std::string& error_message); + + /*! \brief IRModule of the LHS stmt. */ + IRModule lhs_mod_; + /*! \brief Whether assertion mode is enabled. */ + bool assert_mode_; + /*! \brief Whether it is visiting the scope block (the outermost block). */ + bool is_scope_block = true; + /*! \brief The arithmetic analyzer. */ + arith::Analyzer analyzer_; + /*! \brief Additional error messages. Only used when assert_mode is true. */ + std::vector error_messages_; + // variable remap if any + std::unordered_map equal_map_; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index f0b38af01b5f..2368411e6f09 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -378,6 +378,24 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer const String& storage_scope); /******** Schedule: Blockize & Tensorize ********/ + +/*! + * \brief Convert the subtree rooted at a specific loop into a block. + * \param self The state of the schedule + * \param loop_sref The root of the subtree + * \return The new block + */ +TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); + +/*! + * \brief Tensorize the computation enclosed by loop with the tensor intrinsic. + * \param self The state of the schedule + * \param block_or_loop_sref The block or loop to be tensorized. + * \param intrin The tensor intrinsic. + */ +TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, + const TensorIntrin& intrin); + /******** Schedule: Annotation ********/ /*! * \brief Annotate a block/loop with a key value pair diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc new file mode 100644 index 000000000000..bbeb9caaab9b --- /dev/null +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -0,0 +1,698 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace + * represented by the outer loops. + */ +class SubspaceNotDivisibleError : public ScheduleError { + public: + explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block) + : mod_(std::move(mod)), + scope_loop_(std::move(scope_loop)), + inner_block_(std::move(inner_block)) {} + + String FastErrorString() const final { + return "ScheduleError: The bindings of the inner block can not be blockized."; + } + + String DetailRenderTemplate() const final { + return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops " + "starting at {1}."; + } + + IRModule mod() const final { return mod_; } + + Array LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } + + private: + IRModule mod_; + For scope_loop_; + Block inner_block_; +}; + +/*! + * \brief Detect if bindings are a trivial case of the subspace division where we can divide the + * block iter bindings into two categories: + * 1. The binding covers no inner loop vars. + * 2. The binding covers only inner loop vars. + * + * The bindings are not required to be quasi-affine. + * + * \param iter_vars The input iterators + * \param bindings The values of iter_vars + * \param outer_loops Iterators outside the subspace. + * \param inner_loops Iterators of the subspace + * \param predicate The predicate constraint on the input iterators. + * \return The result of the subspace division. + */ +Array> TrivialSubspaceDivision(const Array& iter_vars, + const Array& bindings, + const Array& outer_iters, + const Array& inner_iters, + const PrimExpr& predicate) { + if (!is_one(predicate)) return {}; + Array> res; + std::unordered_set outer_loop_vars; + std::unordered_set inner_loop_vars; + + auto make_uses_var = [](const Array& vars) -> std::function { + std::unordered_set var_set; + var_set.reserve(vars.size()); + for (const Var& var : vars) { + var_set.insert(var.get()); + } + return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool { + return UsesVar(expr, [&var_set](const VarNode* var) { + return var_set.count(var); // + }); + }; + }; + auto use_outer_loop_vars = make_uses_var(outer_iters); + auto use_inner_loop_vars = make_uses_var(inner_iters); + arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); + + for (size_t i = 0; i < bindings.size(); ++i) { + bool outer = use_outer_loop_vars(bindings[i]); + bool inner = use_inner_loop_vars(bindings[i]); + arith::IterMark iter_mark; + if (bindings[i]->IsInstance()) { + iter_mark = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + if (outer && !inner) { + res.push_back({/*outer_iter=*/iter_mark, /*inner_iter=*/unit_iter_mark}); + } else if (inner && !outer) { + res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/iter_mark}); + } else if (!outer && !inner) { + res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/unit_iter_mark}); + } else { + return {}; + } + } + res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), + arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); + return res; +} + +/*! + * \brief Generate the blockized init block. + * \param block The original block with init. 
+ * \param inner_block_realize The block realize of the inner block after blockize. + * \param inner_loops The inner loops after blockize. + * \return The subtree of the init block and its outer loops. + */ +Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize, + const std::vector& inner_loops) { + Array init_block_iters; + Array init_bindings; + const Block& inner_block = inner_block_realize->block; + + // Step 1: Collect data-parallel block iters + for (size_t i = 0; i < inner_block->iter_vars.size(); i++) { + const IterVar& iter_var = inner_block->iter_vars[i]; + const PrimExpr& binding = inner_block_realize->iter_values[i]; + if (iter_var->iter_type == IterVarType::kDataPar && + UsesVar(block->init.value(), + [tgt_var = iter_var->var.get()](const VarNode* var) { return var == tgt_var; })) { + init_block_iters.push_back(iter_var); + init_bindings.push_back(binding); + } + } + + // Step 2: Collect loops related to iters of the init block + std::vector init_loops; + for (const ForNode* inner_loop : inner_loops) { + for (const PrimExpr& init_binding : init_bindings) { + if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const VarNode* var) { + return var == tgt_var; + })) { + init_loops.push_back(inner_loop); + break; + } + } + } + + // Step 3: Create new block iters for the init block + Map subst_map; + for (size_t i = 0; i < init_block_iters.size(); i++) { + IterVar new_iter_var = init_block_iters[i]; + Var old_var = new_iter_var->var; + Var new_var = old_var.copy_with_suffix("_init"); + new_iter_var.CopyOnWrite()->var = new_var; + subst_map.Set(old_var, new_var); + init_block_iters.Set(i, std::move(new_iter_var)); + } + + // Step 4: Generate loop nests and the init block + Stmt new_init = BlockRealize( + /*iter_values=*/init_bindings, + /*predicate=*/inner_block_realize->predicate, + /*block=*/ + Block{/*iter_vars=*/init_block_iters, + /*reads=*/{}, + /*writes=*/block->writes, + /*name_hint=*/block->name_hint + "_init", + /*body=*/block->init.value(), + /*init=*/NullOpt}); + + // Step 5: Generate the parent loops for the init block + for (const ForNode* init_loop : init_loops) { + ObjectPtr new_loop = make_object(*init_loop); + new_loop->loop_var = init_loop->loop_var.copy_with_suffix(""); + subst_map.Set(init_loop->loop_var, new_loop->loop_var); + new_loop->body = std::move(new_init); + new_init = For(new_loop); + } + + // Step 6: Substitute with new loop variables and block iters to prevent duplication of + // variables in the outer block. + new_init = Substitute(new_init, subst_map); + + return new_init; +} + +/*! + * \brief A helper to collect the parent loops of the block. The loops are divided into two groups, + * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the + * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its + * successor loops. It is possible that 'outer_loops' is empty. + */ +class LoopSubspaceCollector { + public: + /*! + * \brief Collect the parent loops of the block and store the result in the corresponding fields. + * \param block_sref The sref to the target block. + * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop. 
+ */ + void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) { + bool inner = true; + for (StmtSRefNode* current_sref = block_sref->parent; + current_sref && current_sref->stmt->IsInstance(); + current_sref = current_sref->parent) { + const auto* current_loop = current_sref->StmtAs(); + ICHECK(current_loop); + if (inner) { + inner_loops.push_back(current_loop); + inner_loop_vars.push_back(current_loop->loop_var); + } else { + outer_loops.push_back(current_loop); + outer_loop_vars.push_back(current_loop->loop_var); + } + loop_var_domain.Set(current_loop->loop_var, + Range::FromMinExtent(current_loop->min, current_loop->extent)); + if (current_sref == loop_sref.get()) inner = false; + } + } + /*! \brief Outer loops which are ancestors of the separator. */ + std::vector outer_loops; + /*! \brief Inner loops which are the separator itself or its successors. */ + std::vector inner_loops; + /*! \brief Loop variables of the outer loops. */ + Array outer_loop_vars; + /*! \brief Loop variables of the inner loops. */ + Array inner_loop_vars; + /*! \brief Domain of the loop variables. */ + Map loop_var_domain; +}; + +/*! + * \brief Check the bindings of the block iters can be divided by a subspace collected by the + * collector. + * \param mod The current IR module. + * \param block_realize The block realize to be checked. + * \param collector The collector which has collected the loops of the block. + * \param analyzer The arithmetic analyzer. + * \return The result of the subspace division. + * \throws ScheduleError If the bindings are not divisible by the subspace. + */ +Array> CheckSubspaceDivisible(const IRModule& mod, + const BlockRealize& block_realize, + const LoopSubspaceCollector& collector, + arith::Analyzer* analyzer) { + const Block& block = block_realize->block; + DiagnosticContext diag_ctx(DiagnosticContext::Default(mod)); + + Array> division = + arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, + collector.inner_loop_vars, block_realize->predicate, + /*require_bijective=*/false, analyzer, diag_ctx); + + if (division.empty()) { + // If we can't do perfect subspace division, check if it is a trivial case of subspace division. + // In this case, we can still blockize. + division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, + collector.outer_loop_vars, collector.inner_loop_vars, + block_realize->predicate); + } + if (division.empty()) { + throw SubspaceNotDivisibleError(mod, GetRef(collector.inner_loops.back()), block); + } + return division; +} + +/*! + * \brief The binding extractor to compute the bindings of the outer and the inner blocks after + * blockize. + */ +class BlockizedBindingExtractor { + public: + /*! + * \brief Extract bindings for blockize. + * \param iter_vars The iter vars of the original inner block. + * \param division The result of the subspace division. 
+ */ + void ExtractBindings(const Array& iter_vars, + const Array>& division, arith::Analyzer* analyzer) { + ICHECK_EQ(iter_vars.size() + 1, division.size()); + for (size_t i = 0; i < iter_vars.size(); ++i) { + const IterVar& iter_var = iter_vars[i]; + arith::IterMark outer_mark = division[i][0]; + arith::IterMark inner_mark = division[i][1]; + const auto* outer_binding = + TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode); + const auto* inner_binding = + TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode); + + // After computing the subspace division, bindings[i] can be written as + // outer_binding * inner_binding->extent + inner_binding + // The outer block will have binding: iter_outer -> outer_binding + // The inner block will have binding: iter_inner -> inner_binding + // The iter in the original block will be substituted with base + iter_inner where + // base == iter_outer * iter_inner_extent + + if (is_one(division[i][1]->extent)) { // IsOuter + // extract this iter var to outer block directly + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_iter_vars.push_back(iter_var); + } else { + // create iter var for the outer block + const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, division[i][0]->extent), + /*var=*/iter_var->var.copy_with_suffix("_o"), + /*iter_type=*/iter_var->iter_type); + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_iter_vars.push_back(outer_var); + PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; + // create iter var for the inner block + IterVar new_iter = iter_var; + auto* new_iter_node = new_iter.CopyOnWrite(); + new_iter_node->dom = Range::FromMinExtent(0, division[i][1]->extent); + inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom)); + analyzer->Bind(new_iter->var, new_iter->dom); + inner_iter_vars.push_back(new_iter); + inner_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); + inner_iter_subst_map.Set(iter_var->var, base + new_iter->var); + } + } + } + Map inner_iter_subst_map; + /*! \brief Iters of the outer block. */ + Array outer_iter_vars; + /*! \brief Iters of the outer block. */ + Array inner_iter_vars; + /*! \brief Binding values of the outer block. */ + Array outer_bindings; + /*! \brief Binding values of the inner block. */ + Array inner_bindings; + /*! \brief The domain of the inner block iters. */ + Map inner_iter_dom_map; +}; + +/*! + * \brief Replacer for the inner block after blockize. Inner block iters will be replaced with + * base + inner_iter and the expressions after substituion will be simplified if possible. + */ +class InnerIterReplacer : public StmtExprMutator { + public: + /*! + * \brief The constructor + * \param subst_map The substitution map of the inner block iters. + * \param analyzer The arithmetic analyzer. + * \param block_sref_reuse The map to save the block reuse information. 
+ */ + InnerIterReplacer(Map subst_map, arith::Analyzer* analyzer, + Map* block_sref_reuse) + : subst_map_(std::move(subst_map)), + analyzer_(analyzer), + block_sref_reuse_(block_sref_reuse) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = subst_map_.find(GetRef(op)); + if (it != subst_map_.end()) { + return (*it).second; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr(const PrimExpr& op) final { + PrimExpr result = StmtExprMutator::VisitExpr(op); + if (!result.same_as(op)) { + return analyzer_->Simplify(result); + } + return result; + } + + Stmt VisitStmt_(const BlockNode* op) final { + Stmt result = StmtExprMutator::VisitStmt_(op); + if (!result.same_as(GetRef(op))) { + block_sref_reuse_->Set(GetRef(op), Downcast(result)); + } + return result; + } + + private: + Map subst_map_; + arith::Analyzer* analyzer_; + Map* block_sref_reuse_; +}; + +/*! + * \brief Compute the access region of the outer block by relaxing the inner loops. + * \param buffer_region The original buffer region. + * \param The range of the inner loops. + * \return The new buffer region. + */ +BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, + const Map& inner_iter_relaxed_range) { + Array new_region; + new_region.reserve(buffer_region->region.size()); + Array relaxed_int_set = + arith::EvalSet(buffer_region->region, inner_iter_relaxed_range); + ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size()); + for (size_t i = 0; i < buffer_region->region.size(); i++) { + Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]); + new_region.push_back(relaxed_int_set[i].CoverRange(max_range)); + } + return BufferRegion(buffer_region->buffer, std::move(new_region)); +} + +/*! + * \brief Generate the outer block after blockize. + * \param extractor The binding extractor which has extracted the blockized bindings. + * \param block The original inner block. + * \param inner_block_realize The block realize of the inner block after blockize. + * \param inner_loops The inner loops after blockize. + * \param predicate The outer predicate of the subspace division. + * \return The block realize of the outer block after blockize. + */ +BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor, + const Block& block, BlockRealize inner_block_realize, + const std::vector& inner_loops, + PrimExpr predicate) { + // Step 1: Generate the init block if needed + Optional new_init = NullOpt; + if (block->init.defined()) { + new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops); + } + + // Step 2: Compute the access regions of the outer block by relaxing the inner loops + Array new_reads = block->reads; + Array new_writes = block->writes; + + auto f_mutate = [&](const BufferRegion& buffer_region) { + return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_dom_map); + }; + new_reads.MutateByApply(f_mutate); + new_writes.MutateByApply(f_mutate); + + // Step 3: Generate the body of the outer block. The body of the outer block is the inner block + // realize and its surrounding loops. + Stmt outer_block_body = inner_block_realize; + for (const ForNode* loop : inner_loops) { + ObjectPtr new_loop = make_object(*loop); + new_loop->body = std::move(outer_block_body); + outer_block_body = For(new_loop); + } + + // Step 4: Generate the outer block and block realize. 
+ return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings), + /*predicate=*/std::move(predicate), + /*block=*/ + Block(/*iter_vars=*/std::move(extractor.outer_iter_vars), // + /*reads=*/std::move(new_reads), // + /*writes=*/std::move(new_writes), // + /*name_hint=*/block->name_hint + "_o", // + /*body=*/std::move(outer_block_body), // + /*init=*/std::move(new_init))); +} + +StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + arith::Analyzer analyzer; + + // Step 1: Check the loop has a single child BlockRealize on the sref tree. + BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); + Block block = block_realize->block; + StmtSRef block_sref = self->stmt2ref.at(block.get()); + + // Step 2: Collect loops inside and outside loop_sref. + LoopSubspaceCollector collector; + collector.Collect(block_sref, loop_sref); + + // Step 3: Calculate subspace division for the inner loops. + Array> division = + CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer); + + // Step 4: Generate bindings for the outer block and the inner block based on the result of + // the subspace division. + BlockizedBindingExtractor extractor; + extractor.ExtractBindings(block->iter_vars, division, &analyzer); + const PrimExpr& outer_pred = division.back()[0]->extent; + const PrimExpr& inner_pred = division.back()[1]->extent; + + // Step 5: Substitute the iter vars in the original block with the inner iters after the subspace + // division + Map block_sref_reuse; + InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer, + &block_sref_reuse); + Block new_block = Downcast(replacer(block)); + + // Step 6: Generate the inner block. + BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite(); + inner_block_realize->iter_values = extractor.inner_bindings; + inner_block_realize->predicate = inner_pred; + inner_block_realize->block = new_block; + BlockNode* inner_block = inner_block_realize->block.CopyOnWrite(); + inner_block->iter_vars = extractor.inner_iter_vars; + inner_block->init = NullOpt; + block_sref_reuse.Set(block, inner_block_realize->block); + + // Step 6: Generate the outer block. + BlockRealize outer_realize = + GenerateBlockizedOuterBlock(extractor, new_block, GetRef(inner_block_realize), + collector.inner_loops, outer_pred); + // Step 7: Do the actual replacement + self->Replace(loop_sref, outer_realize, block_sref_reuse); + + // Step 8: Update the cached flags + StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); + StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false, + /*require_subtree_compact_dataflow=*/false); + bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); + self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + self->block_info[scope_root].affine_binding = scope_block_affine_binding; + return outer_block_sref; +} + +/*! + * \brief Update the map from the buffers in the desc to the impl of the tensor + * intrinsic. + * \param intrinsic The tensor intrinsic. + * \param buffer_map The map to be updated. 
+ */ +void RemapTensorIntrinBuffers( + const TensorIntrin& intrinsic, + std::unordered_map* buffer_map) { + ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size()); + for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) { + const Var& lhs_var = intrinsic->desc->params[i]; + const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var]; + const Var& rhs_var = intrinsic->impl->params[i]; + const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var]; + (*buffer_map)[rhs_buffer] = lhs_buffer; + } +} + +void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, + const TensorIntrin& intrinsic) { + /*! + * Check: + * - Check buffer binding, including type, alignment, shape and etc. + * - Check the sub AST is equal to the desc function. + * + * Mutate: + * - Blockize the sub AST (please refer blockize for details) + * - Bind buffers + * - Mutate the impl of the tensor intrinsic by replacing its buffers with new + * buffers created via match buffer region. + * - Replace the sub tree with the mutated function. + */ + const BlockRealize& desc_block_realize = Downcast(intrinsic->desc->body); + const BlockRealize& impl_block_realize = Downcast(intrinsic->impl->body); + Block impl_block = impl_block_realize->block; + + // Step 1: Blockize the subtree rooted at the given loop if needed + StmtSRef block_sref{nullptr}; + if (block_or_loop_sref->StmtAs()) { + block_sref = Blockize(self, block_or_loop_sref); + } else { + ICHECK(block_or_loop_sref->StmtAs()); + block_sref = block_or_loop_sref; + } + const BlockRealize& block_realize = GetBlockRealize(self, block_sref); + + // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence + // between buffers in the block and the desc. + TensorizeComparator comparator(self->mod, /*assert_mode=*/true); + comparator.VisitStmt(block_realize, desc_block_realize); + + // Step 3: Find the correspondence between buffers in the current AST and the impl of + // the tensor intrinsic + // Step 3.1: Map from intrinsic func buffer to desc func buffer + std::unordered_map intrin_buffer_map; + RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map); + // Step 3.2: Map form intrinsic func buffer to current AST buffer + std::unordered_map buffer_map; + for (const auto& pair : intrin_buffer_map) { + auto it = comparator.rhs_buffer_map_.find(pair.second); + ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second; + buffer_map[pair.first] = it->second; + } + + // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor + // intrin to make them subregions of the buffer in the original IR. 
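+  // For illustration (see the tensorize unit tests added below): when the 16x16 mma intrinsic is
+  // matched against a 128x128 buffer A in the original IR, the impl buffer is bound to a
+  // subregion roughly like
+  //   A_sub = T.match_buffer(A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], [16, 16], ...)
+  // If the source buffer has more dimensions than the intrinsic buffer (e.g. the batch dimension
+  // in the batch_matmul tests), the leading `offset` dimensions below are pinned to extent-1
+  // ranges at the base indices detected by the comparator.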
+  std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> buffer_region_map;
+  for (const BufferRegion& read : impl_block->reads) {
+    buffer_region_map.emplace(read->buffer, read->region);
+  }
+  for (const BufferRegion& write : impl_block->writes) {
+    buffer_region_map.emplace(write->buffer, write->region);
+  }
+  Array<MatchBufferRegion> match_buffer_regions;
+  match_buffer_regions.reserve(intrinsic->impl->params.size());
+  for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) {
+    const auto& param = intrinsic->impl->params[i];
+    const auto& buffer = intrinsic->impl->buffer_map.at(param);
+    const auto& source = buffer_map.at(buffer);
+    // add the detected base indices to each buffer access region of the tensor intrinsic
+    Region old_region = buffer_region_map.at(buffer);
+    const auto& indices_base = comparator.buffer_indices_.at(source);
+    int offset = static_cast<int>(indices_base.size()) - static_cast<int>(old_region.size());
+    ICHECK(offset >= 0);
+    Region new_region;
+    new_region.reserve(source->shape.size());
+    for (int i = 0; i < offset; i++) {
+      new_region.push_back(Range::FromMinExtent(indices_base[i], 1));
+    }
+    for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
+      new_region.push_back(Range::FromMinExtent(indices_base[i + offset], old_region[i]->extent));
+    }
+    match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, new_region)));
+  }
+
+  // Step 5: Replace the subtree in the original IR with the tensor intrin impl.
+  ObjectPtr<BlockNode> new_block_ptr = make_object<BlockNode>(*block_realize->block.get());
+  new_block_ptr->body = impl_block->body;
+  ICHECK(new_block_ptr->match_buffers.empty());
+  new_block_ptr->match_buffers = std::move(match_buffer_regions);
+  Block new_block(new_block_ptr);
+
+  self->Replace(block_sref, new_block, {{block_realize->block, new_block}});
+
+  // Step 6: Update the cached flags.
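+  // The replacement above only rewrites the subtree under the enclosing scope block, so it is
+  // sufficient to recompute the BlockInfo of that scope root here.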
+ StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_subtree_compact_dataflow=*/false); + self->UpdateScopeBlockInfo(static_cast(scope_root->stmt)->body); +} + +/******** InstructionKind Registration ********/ + +struct BlockizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Blockize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Blockize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("blockize"); + py.Input("loop", loop_rv); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct TensorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Tensorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin) { + if (const auto* block = block_or_loop_rv.as()) { + sch->Tensorize(GetRef(block), intrin); + } else if (const auto* loop = block_or_loop_rv.as()) { + sch->Tensorize(GetRef(loop), intrin); + } else { + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); + } + } + + static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin) { + PythonAPICall py("tensorize"); + py.Input("block_or_loop", block_or_loop_rv); + py.Input("intrin", intrin); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 6e33862c07ca..b466843f9459 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -185,6 +185,20 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") .set_body_method(&ScheduleNode::SetScope); /******** (FFI) Blockize & Tensorize ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") + .set_body_method(&ScheduleNode::Blockize); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") + .set_body_typed([](Schedule self, ObjectRef rv, String intrin) { + if (const auto* block_rv = rv.as()) { + self->Tensorize(GetRef(block_rv), intrin); + } else if (const auto* loop_rv = rv.as()) { + self->Tensorize(GetRef(loop_rv), intrin); + } else { + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". 
Its value is: " << rv; + } + }); + /******** (FFI) Annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 04b7dd5ea2af..3a37f81b5dbc 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -201,9 +201,7 @@ class BlockInfoCollector : private StmtVisitor { bool is_root_block = srefs_.empty(); // Calculate `BlockInfo::scope` Array child_block_srefs = std::move(block_frames_.back()); - BlockInfo& info = - self_->block_info.emplace(scope_root, BlockInfo(BlockScope(child_block_srefs))) - .first->second; + BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); // Set `affine_binding` if (is_root_block) { // If the block doesn't have outer loops and BlockRealize, diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index da7a2641b162..1e2e57eb6eca 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -356,6 +356,37 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /******** Schedule: Blockize & Tensorize ********/ +BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { + BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv); + static const InstructionKind& kind = InstructionKind::Get("Blockize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{new_block})); + return new_block; +} + +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) { + ConcreteScheduleNode::Tensorize(loop_rv, intrin); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{intrin}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) { + ConcreteScheduleNode::Tensorize(block_rv, intrin); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{intrin}, + /*outputs=*/{})); +} + /******** Schedule: Annotation ********/ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index b35f1b6e17bb..3a88e869d309 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -87,6 +87,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { int offset) final; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) final; + void Tensorize(const BlockRV& block_rv, const String& intrin) final; + void Tensorize(const LoopRV& loop_rv, const String& intrin) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py new file mode 100644 index 000000000000..b4a16a8231b8 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under 
one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest +import tvm +from tvm.script import tir as T +from tvm import tir +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +@T.prim_func +def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + +@T.prim_func +def single_elementwise_blockized1( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + with T.block("blockized_B"): + vio = T.axis.spatial(1, 0) + vjo = T.axis.spatial(1, 0) + T.reads(A[0:128, 0:128]) + T.writes(B[0:128, 0:128]) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + +@T.prim_func +def single_elementwise_blockized2( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + for i in T.serial(128): + with T.block("blockized_B"): + vi = T.axis.spatial(128, i) + vjo = T.axis.spatial(1, 0) + T.reads(A[vi, 0:128]) + T.writes(B[vi, 0:128]) + for j in T.serial(128): + with T.block("B"): + vj = T.axis.remap("S", [j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + +@T.prim_func +def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def two_elementwise_blockized( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"] +) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i_0, j_0 in T.grid(8, 8): + with T.block("blockized_B"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i_1, j_1]) + T.reads(A[vio * 16 + vi, vjo * 16 + vj]) + T.writes(B[vio * 16 + vi, vjo * 16 + vj]) + B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2) + with T.block("blockized_C"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + 
T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for ax0, ax1 in T.grid(16, 16): + with T.block("C"): + vi, vj = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[vio * 16 + vi, vjo * 16 + vj]) + T.writes(C[vio * 16 + vi, vjo * 16 + vj]) + C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 16 + vj] + T.float32(1) + + +@T.prim_func +def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: + for k, i in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [k, i]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: + with T.block("blockized_B"): + vko = T.axis.R(1, 0) + vio = T.axis.S(1, 0) + with T.init(): + for i1 in T.serial(0, 128): + with T.block("B_init"): + vi_init = T.axis.S(128, i1) + B[vi_init] = T.float32(0) + for i0, i1_1 in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [i0, i1_1]) + B[vi] = B[vi] + A[vi, vk] + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +def test_blockize_outer(): + func = single_elementwise + # schedule + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + x, y = s.get_loops(B) + s.blockize(x) + print(s.mod['main'].script()) + tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_inner(): + func = single_elementwise + # schedule + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + x, y = s.get_loops(B) + s.blockize(y) + tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized2) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_two_elementwise_blockize_reverse_compute_at(): + func = two_elementwise + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.reverse_compute_at(C, yo) + s.blockize(s.get_loops(C)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_two_elementwise_blockize_compute_at(): + func = two_elementwise + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(C) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.compute_at(B, yo) + s.blockize(s.get_loops(B)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_init_loops(): + s = tir.Schedule(rowsum, debug_mask="all") + k, _ = s.get_loops(s.get_block("B")) + s.blockize(k) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) + verify_trace_roundtrip(sch=s, mod=rowsum) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py new file mode 100644 index 000000000000..401a39f379b7 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -0,0 +1,431 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or 
more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +@T.prim_func +def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, ()) + + with T.block("root"): + T.reads(C[()], A[0 : 4], B[0 : 4]) + T.writes(C[()]) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.remap("R", [i]) + C[()] = C[()] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), offset_factor=1) + B = T.match_buffer(b, (4,), offset_factor=1) + C = T.match_buffer(c, (), offset_factor=1) + + with T.block("root"): + T.reads(C[()], A[0 : 4], B[0 : 4]) + T.writes(C[()]) + T.evaluate( + T.call_extern( + "vec4add", + C.data, + C.elem_offset, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 1), offset_factor=1) + B = T.match_buffer(b, (16, 1), offset_factor=1) + C = T.match_buffer(c, (16, 16), offset_factor=1) + + with T.block("root"): + T.reads( + C[0 : 16, 0 : 16], + A[0 : 16, 0 : 1], + B[0 : 16, 0 : 1], + ) + T.writes(C[0 : 16, 0 : 16]) + for i, j in T.grid(16, 16): + with T.block("update"): + vii, vjj = T.axis.remap("SS", 
[i, j]) + C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0] + + +@T.prim_func +def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 1), offset_factor=1) + B = T.match_buffer(b, (16, 1), offset_factor=1) + C = T.match_buffer(c, (16, 16), offset_factor=1) + + with T.block("root"): + T.reads( + C[0 : 16, 0 : 16], + A[0 : 16, 0 : 1], + B[0 : 16, 0 : 1], + ) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate( + T.call_extern( + "outer_product", + C.data, + C.elem_offset, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def matmul( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], +) -> None: + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + + for i_outer, j_outer in T.grid(8, 8): + for i_inner_init, j_inner_init in T.grid(16, 16): + with T.block("init"): + vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init)) + vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init)) + C[vi_init, vj_init] = T.float32(0) + for k_outer in T.grid(8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) + T.reads( + [ + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + ] + ) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A_elem_offset = T.var("int32") + B_elem_offset = T.var("int32") + C_elem_offset = T.var("int32") + A_sub = T.match_buffer( + A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + [16, 16], + elem_offset=A_elem_offset, + ) + B_sub = T.match_buffer( + B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + [16, 16], + elem_offset=B_elem_offset, + ) + C_sub = T.match_buffer( + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + [16, 16], + elem_offset=C_elem_offset, + ) + T.evaluate( + T.tvm_mma_sync( + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + A_sub.data, + T.floordiv(A_sub.elem_offset, 256), + B_sub.data, + T.floordiv(B_sub.elem_offset, 256), + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + dtype="handle", + ) + ) + + +@T.prim_func +def batch_matmul( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@T.prim_func +def tensorized_batch_matmul_mma( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n in range(0, 16): + 
for i, j, k in T.grid(8, 8, 8): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + T.reads( + C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + ) + T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A_elem_offset = T.var("int32") + B_elem_offset = T.var("int32") + C_elem_offset = T.var("int32") + A_sub = T.match_buffer( + A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + elem_offset=A_elem_offset, + ) + B_sub = T.match_buffer( + B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + elem_offset=B_elem_offset, + ) + C_sub = T.match_buffer( + C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + elem_offset=C_elem_offset, + ) + T.evaluate( + T.tvm_mma_sync( + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + A_sub.data, + T.floordiv(A_sub.elem_offset, 256), + B_sub.data, + T.floordiv(B_sub.elem_offset, 256), + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + dtype="handle", + ) + ) + + +@T.prim_func +def tensorized_batch_matmul_dot_product( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n, i, j, k_0 in T.grid(16, 128, 128, 32): + with T.block("blockized_update"): + vn, vi, vj, vko = T.axis.remap("SSSR", [n, i, j, k_0]) + T.reads( + C[vn, vi, vj], A[vn, vi, vko * 4 : vko * 4 + 4], B[vn, vj, vko * 4 : vko * 4 + 4] + ) + T.writes(C[vn, vi, vj]) + A_1 = T.match_buffer( + A[vn, vi, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1 + ) + B_1 = T.match_buffer( + B[vn, vj, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1 + ) + C_1 = T.match_buffer(C[vn, vi, vj], [], dtype="float32", offset_factor=1) + T.evaluate( + T.call_extern( + "vec4add", + C_1.data, + C_1.elem_offset, + A_1.data, + A_1.elem_offset, + B_1.data, + B_1.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def tensorized_batch_matmul_outer_product( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n, i_0, j_0, k in T.grid(16, 8, 8, 128): + with T.block("blockized_update"): + vn, vio, vjo, vk = T.axis.remap("SSSR", [n, i_0, j_0, k]) + T.reads( + C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + A[vn, vio * 16 : vio * 16 + 16, vk], + B[vn, vjo * 16 : vjo * 16 + 16, vk], + ) + T.writes(C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + A_1 = T.match_buffer(A[vn, vio * 16 : vio * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1) + B_1 = T.match_buffer(B[vn, vjo * 16 : vjo * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1 + ) + C_1 = T.match_buffer( + C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], [16, 16], dtype="float32", offset_factor=1 + ) + T.evaluate( + T.call_extern("outer_product", C_1.data, C_1.elem_offset, A_1.data, A_1.elem_offset, + B_1.data, B_1.elem_offset, dtype="int32" + ) + ) + + +# fmt: off +# pylint: 
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) +tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_intrin) +tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, outer_product_intrin) + + +def test_tensorize_matmul(): + func = matmul + # schedule + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.decompose_reduction(update, ko) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_matmul, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_tensorize_batch_matmul(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + _, i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_mma, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=batch_matmul) + + +def test_tensorize_dot_product(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + C = s.get_block("update") + _, _, _, k = s.get_loops(C) + _, ki = s.split(k, factors=[None, 4]) + s.tensorize(ki, "test_dot_product_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_dot_product, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_tensorize_outer_product(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + C = s.get_block("update") + _, i, j, k = s.get_loops(C) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + s.reorder(io, jo, k, ii, ji) + s.tensorize(ii, "test_outer_product_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_outer_product, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))
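+
+
+# The following is a minimal sketch (not collected by pytest, since its name does not start with
+# "test_") of how a tensorize call is recorded in the schedule trace and replayed on a fresh
+# schedule. It reuses the `matmul` workload and the "test_mma_intrin" intrinsic registered above,
+# and assumes only the public Schedule/Trace APIs (`Schedule.trace`, `Trace.apply_to_schedule`).
+def _example_tensorize_trace_roundtrip():
+    s = tir.Schedule(matmul, debug_mask="all")
+    update = s.get_block("update")
+    i, j, k = s.get_loops(update)
+    io, ii = s.split(i, factors=[None, 16])
+    jo, ji = s.split(j, factors=[None, 16])
+    ko, ki = s.split(k, factors=[None, 16])
+    s.reorder(io, jo, ko, ii, ji, ki)
+    s.decompose_reduction(update, ko)
+    s.tensorize(ii, "test_mma_intrin")
+    # The trace now ends with a Tensorize instruction; replaying it on a fresh schedule should
+    # reproduce the same module, which is what verify_trace_roundtrip checks in the tests above.
+    fresh = tir.Schedule(matmul, debug_mask="all")
+    s.trace.apply_to_schedule(fresh, remove_postproc=True)
+    tvm.ir.assert_structural_equal(s.mod["main"], fresh.mod["main"])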