diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 9eead8d5ec31..a1dd4a412eec 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -183,6 +183,10 @@ class DatabaseNode : public runtime::Object { * - "structural": Use StructuralEqual/Hash * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during * equality testing and hashing. + * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + * given module. The "ignore-ndarray" variant is used for the extracted blocks + * or in case no anchor block is found. + * For the definition of the anchor block, see tvm/tir/analysis.h. */ explicit DatabaseNode(String mod_eq_name = "structural"); @@ -274,6 +278,10 @@ class PyDatabaseNode : public DatabaseNode { * - "structural": Use StructuralEqual/Hash * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during * equality testing and hashing. + * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + * given module. The "ignore-ndarray" variant is used for the extracted blocks + * or in case no anchor block is found. + * For the definition of the anchor block, see tvm/tir/analysis.h. */ explicit PyDatabaseNode(String mod_eq_name = "structural"); diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 3bc30e09c74a..1b018512146f 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -245,9 +245,12 @@ class ScheduleRule : public runtime::ObjectRef { * \brief Auto bind loops around the block to BlockIdx and ThreadIdx * \param max_threadblocks The maximum number of threadblocks on GPU * \param thread_extents Candidates of thread axis extent. + * \param max_threads_per_block The maximum number of threads per block, if it is known + * when this schedule rule is created. * \return The schedule rule created */ - TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents); + TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents, + int max_threads_per_block = -1); /*! * \brief Create a schedule rule with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 04acdc9d4a75..0dd679e047e0 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -30,6 +30,7 @@ search_strategy, space_generator, tir_integration, + trace_apply, ) from .builder import Builder from .cost_model import CostModel diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index f81d8913c18a..102a13b90d98 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ b/python/tvm/meta_schedule/database/json_database.py @@ -40,6 +40,10 @@ class JSONDatabase(Database): - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. 
""" path_workload: str diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 96b9bb5a0112..34a6a141970a 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -33,6 +33,10 @@ class MemoryDatabase(Database): - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" varint is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. """ def __init__( diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py index 7a0b433996c5..c7d175cb79d3 100644 --- a/python/tvm/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -39,6 +39,10 @@ class ScheduleFnDatabase(Database): - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" varint is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. """ def __init__( diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 089f6e412e20..5e77181d32bf 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -143,6 +143,10 @@ def extract_tasks( - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" varint is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. Returns ------- @@ -288,6 +292,10 @@ def tune_relay( - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" varint is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. Returns ------- diff --git a/python/tvm/meta_schedule/trace_apply.py b/python/tvm/meta_schedule/trace_apply.py new file mode 100644 index 000000000000..c621cf973af2 --- /dev/null +++ b/python/tvm/meta_schedule/trace_apply.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Specialized applications of trace""" +from ..tir.schedule import Schedule, Trace +from ..target import Target +from . import _ffi_api + + +def schedule_using_anchor_trace(sch: Schedule, anchor_trace: Trace, target: Target) -> None: + """Apply the trace from a TIR module whose anchor block is the same but fused elementwise op + blocks differ. This function can be used for transferring a trace tuned on a conv2d -> add + subgraph to other subgraphs having the same conv2d workload, for example. We call such a trace + an "anchor trace". Those blocks that are not scheduled by the given anchor trace will be either + inlined or parallelized. + + Parameters + ---------- + sch : Schedule + The target schedule + anchor_trace: Trace + The trace generated for another TIR module having the same anchor block + target : tvm.target.Target + The compilation target + """ + _ffi_api.ScheduleUsingAnchorTrace(sch, anchor_trace, target) # type: ignore diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 07021eac3998..a69c8f126272 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -76,6 +76,10 @@ def tune_tasks( - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. Returns ------- diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py index d7db182f9d20..662dd10ec068 100644 --- a/python/tvm/script/tir/__init__.py +++ b/python/tvm/script/tir/__init__.py @@ -25,8 +25,9 @@ # add all floating point and integer datatypes to the module for _dtype in ["float", "uint", "int"]: for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32"]: + for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: from . import ty _name = _dtype + _size + _lanes - globals()[_name] = getattr(ty, _name) + if hasattr(ty, _name): + globals()[_name] = getattr(ty, _name) diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index b8323dd4a167..b17b571e88e7 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -202,7 +202,7 @@ def __getitem__(self, args): # add all floating point and integer datatypes to the module for _dtype in ["float", "uint", "int"]: for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32"]: + for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: _name = _dtype + _size + _lanes globals()[_name] = ConcreteType(_name) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 90c585ac8ce1..e1c0019d9bf0 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -122,3 +122,21 @@ def get_auto_tensorize_mapping_info( intrinsics. 
""" return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore + + +def has_block(sch: Schedule, block_name: str) -> bool: + """Query if the given block name exists in the module associated with the provided schedule. + + Parameters + ---------- + sch : Schedule + The schedule + block_name : str + The name of the block to query + + Returns + ------- + yes/no: bool + True if the given block exists in the schedule. + """ + return _ffi_api.HasBlock(sch, block_name) # type: ignore diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index caa7da170bd6..f9ffe82aa271 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -73,11 +74,34 @@ class ModuleEqualityIgnoreNDArray : public ModuleEquality { } }; +// The NDArray-ignoring variant of structural equal / hash is used for the module equality +// on the extracted anchor blocks. +class ModuleEqualityAnchorBlock : public ModuleEquality { + size_t Hash(IRModule mod) const { + auto anchor_block = tir::FindAnchorBlock(mod); + if (anchor_block) { + return SHashHandlerIgnoreNDArray().Hash(GetRef(anchor_block), false); + } + return ModuleEqualityIgnoreNDArray().Hash(mod); + } + bool Equal(IRModule lhs, IRModule rhs) const { + auto anchor_block_lhs = tir::FindAnchorBlock(lhs); + auto anchor_block_rhs = tir::FindAnchorBlock(rhs); + if (anchor_block_lhs && anchor_block_rhs) { + return SEqualHandlerIgnoreNDArray().Equal(GetRef(anchor_block_lhs), + GetRef(anchor_block_rhs), false); + } + return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); + } +}; + std::unique_ptr ModuleEquality::Create(const std::string& mod_eq_name) { if (mod_eq_name == "structural") { return std::make_unique(); } else if (mod_eq_name == "ignore-ndarray") { return std::make_unique(); + } else if (mod_eq_name == "anchor-block") { + return std::make_unique(); } LOG(FATAL) << "Unknown module equality " << mod_eq_name; return nullptr; diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h index 8c99b563551b..ba5877471e2c 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -42,6 +42,10 @@ class ModuleEquality { * - "structural": Use StructuralEqual/Hash * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during * equality testing and hashing. + * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + * given module. The "ignore-ndarray" varint is used for the extracted blocks + * or in case no anchor block is found. + * For the definition of the anchor block, see tvm/tir/analysis.h. 
* \return An owning pointer to the created instance */ static std::unique_ptr Create(const std::string& mod_eq_name); diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 7af1418d8f3e..4d16a6d4d65d 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -208,10 +208,11 @@ Array AutoBindNode::Apply(const tir::Schedule& sch, const tir::Bl return {sch}; } -ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_extents) { +ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_extents, + int max_threads_per_block) { ObjectPtr n = make_object(); n->max_threadblocks_ = max_threadblocks; - n->max_threads_per_block_ = -1; + n->max_threads_per_block_ = max_threads_per_block; n->thread_extents_ = std::move(thread_extents); return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 8333833bfafa..bd492d03eac6 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -53,14 +53,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( Array ScheduleRule::DefaultLLVM() { return { - ScheduleRule::AutoInline( - /*into_producer=*/false, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/true, - /*require_injective=*/true, - /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + GetDefaultAutoInline("llvm"), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -98,14 +91,7 @@ Array ScheduleRule::DefaultCUDA() { Map{{"req", String("must")}, {"levels", Array{3}}, // {"scope", String("local")}}), - ScheduleRule::AutoInline( - /*into_producer=*/true, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/false, - /*require_injective=*/false, - /*require_ordered=*/false, - /*disallow_op=*/Array{}), + GetDefaultAutoInline("cuda"), ScheduleRule::CrossThreadReduction( /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 53107bafb2c0..bcc0673e5924 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -45,12 +45,11 @@ String GetRuleKindFromTarget(const Target& target) { } return "cuda"; } - if (target->kind->name == "rocm") { - return "cuda"; - } - if (target->kind->name == "vulkan") { + + if (IsGPUTarget(target->kind->name)) { return "cuda"; } + LOG(FATAL) << "Unsupported target: " << target; throw; } diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc new file mode 100644 index 000000000000..70b6451d3546 --- /dev/null +++ b/src/meta_schedule/trace_apply.cc @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "trace_apply.h" + +#include +#include + +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace meta_schedule { + +using namespace tir; + +// Returns true if b1 is an ancestor of b2 +bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { + if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) { + return true; + } + for (auto prod : sch->GetProducers(b2)) { + if (IsAncestor(b1, prod, sch)) return true; + } + return false; +} + +// Inline or reverse inline spatial blocks after the anchor block +void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { + static auto kind_get_block = InstructionKind::Get("GetBlock"); + // We let blocks whose names are referenced in the anchor trace be scheduled by the anchor trace. + // We record such block names to avoid inlining them here. + std::unordered_set get_block_names; + for (const auto& inst : anchor_trace->insts) { + if (inst->kind.same_as(kind_get_block)) { + auto block_name = Downcast(inst->attrs[0]); + ICHECK(block_name.defined()); + get_block_names.insert(block_name); + } + } + + auto anchor_block = FindAnchorBlock(sch->mod()); + + auto inline_rule = GetDefaultAutoInline(target->kind->name); + + for (auto name : GetBlockNames(sch->mod())) { + auto block = sch->GetBlock(name); + if (anchor_block) { + auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint); + if (IsAncestor(block, anchor_block_rv, sch)) continue; + } + // Spatial blocks which are not referenced in the anchor trace will be inlined here. + if (IsSpatial(sch->GetSRef(block)) && !get_block_names.count(name)) { + inline_rule->Apply(sch, block); + } + } +} + +// Apply instructions from the anchor trace to the target schedule, and returns blocks +// that remain unscheduled. +std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { + static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); + static auto kind_get_block = InstructionKind::Get("GetBlock"); + static auto kind_compute_inline = InstructionKind::Get("ComputeInline"); + static auto kind_reverse_compute_inline = InstructionKind::Get("ReverseComputeInline"); + + const auto block_names_orig = GetBlockNames(sch->mod()); + const auto sch_orig = sch->Copy(); + + std::unordered_map rv_map; + // Blocks and loops that appear in the anchor trace but are not part of the target schedule. + std::unordered_set foreign_blocks; + std::unordered_set foreign_loops; + + // Instructions in the anchor trace can be applied only if all inputs are part of the target + // schedule. + auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) { + for (auto input : inst->inputs) { + if (!input.defined()) continue; + if ((input->IsInstance() && foreign_blocks.count(Downcast(input))) || + (input->IsInstance() && foreign_loops.count(Downcast(input)))) { + return false; + } + } + return true; + }; + + for (const auto& inst : anchor_trace->insts) { + if (!is_inst_applicable(inst)) { + // If we find an instruction that is not applicable, its outputs are recorded as "foreign" + // to the target schedule. 
+ for (auto output : inst->outputs) { + if (output->IsInstance<BlockRVNode>()) { + foreign_blocks.insert(Downcast<BlockRV>(output)); + } else if (output->IsInstance<LoopRVNode>()) { + foreign_loops.insert(Downcast<LoopRV>(output)); + } + } + continue; + } + + Array<ObjectRef> inputs = TranslateInputRVs(inst->inputs, rv_map); + + if (inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast<String>(inst->attrs[0]))) { + // The anchor trace does get_block on a block that is not part of the target schedule. + auto block = Downcast<BlockRV>(inst->outputs[0]); + foreign_blocks.insert(block); + continue; + } else if (inst->kind.same_as(kind_reverse_compute_inline)) { + // The anchor trace does reverse_compute_inline on a block, but the block with the same name + // in the target schedule cannot be reverse-compute-inlined. + // In such cases, it should be possible to apply compute_inline instead. + auto block = Downcast<BlockRV>(inputs[0]); + auto block_sref = sch->GetSRef(block); + if (!CanReverseComputeInline(sch->state(), block_sref)) { + ICHECK(CanComputeInline(sch->state(), block_sref)); + sch->ComputeInline(block); + continue; + } + } else if (inst->kind.same_as(kind_compute_inline)) { + // Similar to the reverse_compute_inline case above. + auto block = Downcast<BlockRV>(inputs[0]); + auto block_sref = sch->GetSRef(block); + if (!CanComputeInline(sch->state(), block_sref)) { + ICHECK(CanReverseComputeInline(sch->state(), block_sref)); + sch->ReverseComputeInline(block); + continue; + } + } + + Optional<ObjectRef> decision = anchor_trace->GetDecision(inst); + Array<ObjectRef> outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); + + if (inst->kind.same_as(kind_get_child_blocks)) { + // We want to allow a trace generated for a single conv2d block to be applied to + // conv2d -> elemwise blocks, where the two conv2d are the same workload. + // GetChildBlocks returns a different number of blocks for the two cases above, which + // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() == + // new_outputs.size(). We work around this problem by assuming that the prefix of the "new" + // outputs matches the "old" outputs, and truncating the new outputs accordingly. 
+ ICHECK(inst->outputs.size() <= outputs.size()); + TranslateAddOutputRVs( + inst->outputs, Array(outputs.begin(), outputs.begin() + inst->outputs.size()), + &rv_map); + } else { + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + } + } + + auto is_scheduled = [=](const std::string& block_name) { + auto loops = sch->GetLoops(sch->GetBlock(block_name)); + auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); + if (loops.size() != loops_orig.size()) { + return true; + } + for (size_t i = 0; i < loops.size(); ++i) { + auto loop = sch->Get(loops[i]); + auto loop_orig = sch_orig->Get(loops_orig[i]); + if (loop->kind != loop_orig->kind) { + return true; + } + } + return false; + }; + + const auto block_names_now = GetBlockNames(sch->mod()); + std::vector unscheduled_blocks; + + for (auto name : block_names_orig) { + if (block_names_now.count(name) && name != "root" && !is_scheduled(name)) { + unscheduled_blocks.push_back(sch->GetBlock(name)); + } + } + + return unscheduled_blocks; +} + +void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm::Target& target) { + InlinePostBlocks(sch, anchor_trace, target); + + auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace); + ICHECK(unscheduled_blocks.size() <= 1) + << "All blocks should have been scheduled or only one (fused) spatial block can remain " + "unscheduled at this point."; + + if (unscheduled_blocks.empty()) { + // All blocks have already been scheduled. + return; + } + + auto last_block = unscheduled_blocks[0]; + auto last_block_producers = sch->GetProducers(last_block); + + if (last_block_producers.size() == 1 && IsSpatial(sch->GetSRef(last_block_producers[0]))) { + // Inline into the cache write stage + sch->ReverseComputeInline(last_block); + } else if (target->kind->name == "llvm" || target->kind->name == "hexagon") { + sch->Parallel(sch->Fuse(sch->GetLoops(last_block))); + } else if (IsGPUTarget(target->kind->name)) { + auto max_threads_per_block = target->GetAttr("max_threads_per_block"); + ICHECK(max_threads_per_block.defined()) + << "ValueError: missing attribute `max_threads_per_block` in the target"; + + auto auto_bind_rule = + ScheduleRule::AutoBind(/*max_threadblocks=*/256, + /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024}, + max_threads_per_block.value()->value); + auto_bind_rule->Apply(sch, last_block); + } +} + +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace") + .set_body_typed(ScheduleUsingAnchorTrace); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/trace_apply.h b/src/meta_schedule/trace_apply.h new file mode 100644 index 000000000000..9a9068ab914f --- /dev/null +++ b/src/meta_schedule/trace_apply.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#ifndef TVM_META_SCHEDULE_TRACE_APPLY_H_ +#define TVM_META_SCHEDULE_TRACE_APPLY_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Apply the trace from a TIR module whose anchor block is the same but fused elementwise + * op blocks differ. This function can be used for transferring a trace tuned on a conv2d -> add + * subgraph to other subgraphs having the same conv2d workload, for example. We call such a trace + * an "anchor trace". Those blocks that are not scheduled by the given anchor trace will be either + * inlined or parallelized. + * \param sch The schedule to apply the anchor trace. + * \param anchor_trace The trace tuned on another subgraph with the same anchor-block workload. + * \param target The target information needed for inlining and parallelization. + */ +void ScheduleUsingAnchorTrace(tir::Schedule sch, const tir::Trace& anchor_trace, + const tvm::Target& target); + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_TRACE_APPLY_H_ diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 824cfcd6aa5c..7240fa418839 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -44,6 +44,7 @@ #include #include +#include #include #include @@ -502,6 +503,41 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { } } +/*! \brief Returns true if the given target is one of the supported gpu targets. */ +inline bool IsGPUTarget(const std::string& target_name) { + static const std::unordered_set<std::string> gpu_targets{"cuda", "rocm", "vulkan", "metal"}; + return gpu_targets.count(target_name); +} + +/*! + * \brief Create an AutoInline schedule rule for the given target. + * \param target_name The name of the target ("llvm", "cuda", etc.) + * \return The AutoInline schedule rule for the given target. 
+ */ +inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { + if (target_name == "llvm" || target_name == "hexagon") { + return ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array{"tir.exp"}); + } else if (IsGPUTarget(target_name)) { + return ScheduleRule::AutoInline( + /*into_producer=*/true, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/false, + /*require_injective=*/false, + /*require_ordered=*/false, + /*disallow_op=*/Array{}); + } + LOG(FATAL) << "Unsupported target " << target_name; + return ScheduleRule(nullptr); +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 430b551a3b9e..7e66dafe16f5 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -22,6 +22,8 @@ #include #include +#include + #include "../../meta_schedule/module_equality.h" #include "../../te/operation/create_primfunc.h" #include "./te_compiler_cache.h" @@ -31,6 +33,25 @@ namespace tvm { namespace relay { namespace backend { +class OpCounter : public ExprVisitor { + public: + static size_t GetOpCount(relay::Function func) { + OpCounter counter; + counter(func->body); + return counter.count; + } + + private: + void VisitExpr_(const CallNode* call) final { + if (call->op->IsInstance()) { + ++count; + } + ExprVisitor::VisitExpr_(call); + } + + size_t count{0}; +}; + Array ExtractTask(IRModule mod, Target target, Map params, String mod_eq_name) { @@ -52,33 +73,59 @@ Array ExtractTask(IRModule mod, Target target, std::unordered_map cache( /*bucket_count*/ 0, ModuleHash(*mod_eq), ModuleEqual(*mod_eq)); - PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &tir_converter](const Expr& exp) { + std::vector> lower_results; + + PostOrderVisit(mod->Lookup("main"), [&lower_results, &target, &tir_converter](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { return; } - auto [inputs_outputs, constants, fused_name] = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); if (Optional f = tir_converter(inputs_outputs, constants)) { IRModule tir_mod = PrimFuncToIRModule(f.value()); - - auto it = cache.find(tir_mod); - if (it != cache.end()) { - it->second->weight += 1; - return; - } - - // Note that the cache is key-ed on the tir mod, rather than the relay mod - IRModule relay_mod({{GlobalVar(fused_name), relay_func}}); - ExtractedTask task(fused_name, relay_mod, target, {tir_mod}, 1); - tasks.push_back(task); - cache.emplace(tir_mod, task); + lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod)); } } }); + + std::vector indices(lower_results.size()); + std::iota(indices.begin(), indices.end(), 0); + + if (mod_eq_name == "anchor-block") { + std::vector op_counts(lower_results.size()); + for (size_t i = 0; i < op_counts.size(); ++i) { + op_counts[i] = OpCounter::GetOpCount(std::get<1>(lower_results[i])); + } + + // When anchor-block based equality is used, tuning tasks "nn_conv2d_add_nn_relu" and + // "nn_conv2d_add_add_nn_relu", for example, can be identified as equal. Thus, one of + // them will be filtered by the cache below. 
+ // + // To make sure that we tune "nn_conv2d_add_nn_relu" and not "nn_conv2d_add_add_nn_relu", + // we sort the TE lowering results based on the number of relay ops. This way, + // "nn_conv2d_add_nn_relu" will be added to the cache first, and "nn_conv2d_add_add_nn_relu" + // will be filtered. + std::sort(indices.begin(), indices.end(), + [&op_counts](int i1, int i2) { return op_counts[i1] < op_counts[i2]; }); + } + + for (auto i : indices) { + const auto& [fused_name, relay_func, tir_mod] = lower_results[i]; + auto it = cache.find(tir_mod); + if (it != cache.end()) { + it->second->weight += 1; + continue; + } + // Note that the cache is key-ed on the tir mod, rather than the relay mod + IRModule relay_mod({{GlobalVar(fused_name), relay_func}}); + ExtractedTask task(fused_name, relay_mod, target, {tir_mod}, 1); + tasks.push_back(task); + cache.emplace(tir_mod, task); + } + // Tasks are extracted via post order visit, return the reversed list. std::reverse(tasks.begin(), tasks.end()); NameSupply name_supply = NameSupply(""); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index e7326ed5dd4d..c97efb565d9d 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -44,6 +44,7 @@ #include #include +#include #include #include #include @@ -52,6 +53,8 @@ #include "../../printer/text_printer.h" #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" +#include "../src/meta_schedule/module_equality.h" +#include "../src/meta_schedule/trace_apply.h" #include "../transforms/meta_schedule_layout_rewrite.h" #include "utils.h" @@ -461,7 +464,9 @@ class AllocateConstReplaceConstant : public StmtExprMutator { // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : public ExprVisitor { public: - explicit ScheduleBuilder(Target target) : target_(target) { + explicit ScheduleBuilder(Target target) + : target_(target), + mod_eq_structural_(meta_schedule::ModuleEquality::Create("ignore-ndarray")) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); if (backend::IsMetaScheduleEnabled()) { @@ -614,9 +619,19 @@ class ScheduleBuilder : public ExprVisitor { MetaScheduleLayoutRewriter::LayoutQueuePush(index_map); } } + Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0, tir::ScheduleErrorRenderLevel::kDetail); - record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); + + if (!mod_eq_structural_->Equal(query_mod, opt_record.value()->workload->mod)) { + // When the database lookup succeeds while structural equality check fails, + // it implies that the anchor block based equality has been used during tuning. + // The trace in the record cannot directly be applied to this query module. + meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace, target_); + } else { + record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); + } + IRModule mod = sch->mod(); ICHECK_EQ(mod->functions.size(), 1); mod = tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ false)( @@ -698,6 +713,7 @@ class ScheduleBuilder : public ExprVisitor { int anchor_op_pattern_{0}; bool use_auto_scheduler_; Optional database_; + std::unique_ptr mod_eq_structural_; }; /*! 
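The changes above thread anchor-block equality through task extraction, tuning, and the compile-time fallback in te_compiler_cache.cc. A minimal sketch of the resulting user-facing flow, mirroring the tests added later in this patch (the ResNet-18 workload, trial count, and target string here are placeholders, not part of the patch):

import tempfile

import tvm.meta_schedule as ms
from tvm.relay.testing import resnet

# Placeholder workload; any Relay module whose subgraphs share an anchor
# (e.g. conv2d) workload benefits from anchor-block deduplication.
mod, params = resnet.get_workload(num_layers=18, batch_size=1)
target = "llvm --num-cores=4"

# Subgraphs such as conv2d->add and conv2d->add->add->relu that share the same
# conv2d anchor block are extracted as a single tuning task.
tasks = ms.relay_integration.extract_tasks(mod, target, params, module_equality="anchor-block")

with tempfile.TemporaryDirectory() as work_dir:
    database = ms.relay_integration.tune_relay(
        mod=mod,
        target=target,
        params=params,
        work_dir=work_dir,
        max_trials_global=1000,
        module_equality="anchor-block",  # must match the setting used for extraction
    )
    # For subgraphs whose database hit is only anchor-block equal (not structurally
    # equal), compilation transfers the record's trace via ScheduleUsingAnchorTrace.
    lib = ms.relay_integration.compile_relay(database, mod, target, params)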
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index a95f55357f2d..ef350004ad52 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -354,8 +354,11 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // `max_function_args` was introduced. It specifies the maximum number of kernel arguments. More // information about this limitation can be found here: // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc +// See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) .add_attr_option<Integer>("max_num_threads", Integer(256)) + .add_attr_option<Integer>("max_threads_per_block", Integer(256)) + .add_attr_option<Integer>("max_shared_memory_per_block", Integer(32768)) .add_attr_option<Integer>("thread_warp_size", Integer(16)) .add_attr_option<Integer>("max_function_args", Integer(31)) .set_default_keys({"metal", "gpu"}); diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d8b4f31f4c1b..64cc8013d716 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2059,5 +2059,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); }); +TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index adadb46852cc..2c86c2df2d25 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -297,7 +297,12 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) * \param stage The stage to be inserted * \return A SeqStmt, the result after insertion */ -SeqStmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { +Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { + if (const auto* alloc = stmt.as<AllocateConstNode>()) { + auto seq_stmt = InsertCacheStage(alloc->body, pos, stage); + return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt, + alloc->annotations, alloc->span); + } if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) { ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt); result->seq.insert(result->seq.begin() + pos, stage); diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index c289309acc2d..bcc8b7facbc9 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -31,7 +31,9 @@ #include #include +#include #include +#include #include #include "../../arith/pattern_match.h" @@ -442,6 +444,50 @@ inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { } } +/******** Utilities for retrieving information about blocks ********/ + +/*! \brief Returns the names of the blocks in the provided module. */ +inline std::unordered_set<std::string> GetBlockNames(const IRModule& mod) { + struct BlockNameCollector : public tir::StmtVisitor { + void VisitStmt_(const tir::BlockNode* block) override { + block_names.insert(block->name_hint); + StmtVisitor::VisitStmt(block->body); + } + std::unordered_set<std::string> block_names; + }; + + auto prim_func = tir::FindEntryFunc(mod, nullptr); + BlockNameCollector collector; + collector(prim_func->body); + return collector.block_names; +} + +/*! 
\brief Query if the given block name exists in the module associated with the schedule */ +inline bool HasBlock(const Schedule& sch, const std::string& block_name) { + auto block_names = GetBlockNames(sch->mod()); + return block_names.count(block_name); +} + +/******** Utilities for trace application ********/ + +/*! + * \brief Translate the input objects using the provided substitution map. + * \param inputs The input objects. + * \param rv_map The substitution map for variables. + * \return The transformed objects. + */ +Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs, + const std::unordered_map<const Object*, const Object*>& rv_map); + +/*! + * \brief Update the variable substitution map according to the new outputs. + * \param old_outputs The previous outputs of a schedule instruction. + * \param new_outputs The new outputs of the same schedule instruction. + * \param rv_map The substitution map for variables. + */ +void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs, + std::unordered_map<const Object*, const Object*>* rv_map); + } // namespace tir } // namespace tvm diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index a541c25f3cbc..addbb052a2da 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -116,9 +116,11 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher): postprocs=postprocs, mutator_probs={}, ), - # Without this, the same workloads with different constant weights - # are treated as distinct tuning tasks. - module_equality="ignore-ndarray", + # This enables anchor-block tuning, where different subgraphs + # with the same anchor block workload will be identified as equal. + # It reduces the number of conv2d tuning tasks in the int8 resnet50 model + # from 36 to 23, with negligible performance difference. 
+ module_equality="anchor-block", ) return ms.relay_integration.compile_relay( diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index 9a1c9e8dc7f5..c689a15c56b2 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -107,6 +107,41 @@ def test_meta_schedule_integration_extract_from_resnet(): assert t.task_name in expected_task_names, t.task_name +@requires_torch +def test_task_extraction_anchor_block(): + mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) + extracted_tasks = ms.relay_integration.extract_tasks( + mod, target="llvm", params=params, module_equality="anchor-block" + ) + + # Note that there is no task from residual blocks + expected_task_names = [ + "fused_" + s + for s in [ + "nn_max_pool2d", + "nn_adaptive_avg_pool2d", + "nn_dense_add", + "nn_conv2d_add", + "nn_conv2d_add_1", + "nn_conv2d_add_2", + "nn_conv2d_add_nn_relu", + "nn_conv2d_add_nn_relu_1", + "nn_conv2d_add_nn_relu_2", + "nn_conv2d_add_nn_relu_3", + "nn_conv2d_add_nn_relu_4", + "nn_conv2d_add_nn_relu_5", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1", + "layout_transform", + "layout_transform_reshape_squeeze", + ] + ] + + assert len(extracted_tasks) == len(expected_task_names) + for t in extracted_tasks: + assert t.task_name in expected_task_names, t.task_name + + @requires_torch def test_meta_schedule_integration_extract_from_bert_base(): pytest.importorskip( @@ -673,5 +708,115 @@ def test_module_equality_ignore_ndarray(): np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4) +def _test_anchor_tuning(target): + data_shape = (128, 128) + weight_shape1 = (128, 128) + weight_shape2 = (128, 128) + + data = relay.var("data", shape=data_shape, dtype="float32") + weight1 = relay.var("weight1", shape=weight_shape1, dtype="float32") + weight2 = relay.var("weight2", shape=weight_shape2, dtype="float32") + dense1 = relay.nn.dense(data, weight1) + dense2 = relay.nn.dense(dense1 + relay.const(1.0, dtype="float32"), weight2) + mod = tvm.IRModule.from_expr(dense2 - data + relay.const(1.0, dtype="float32")) + + weight1_np = np.random.randn(*weight_shape1).astype("float32") + weight2_np = np.random.randn(*weight_shape2).astype("float32") + + data_np = np.random.randn(*data_shape).astype("float32") + params = {"weight1": weight1_np, "weight2": weight2_np} + + module_equality = "anchor-block" + + extracted_tasks = ms.relay_integration.extract_tasks( + mod, target, params, module_equality=module_equality + ) + + assert len(extracted_tasks) == 1 + + with tempfile.TemporaryDirectory() as work_dir: + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + params=params, + work_dir=work_dir, + max_trials_global=4, + strategy="replay-trace", + module_equality=module_equality, + ) + lib = ms.relay_integration.compile_relay(database, mod, target, params) + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + out = runtime.get_output(0).numpy() + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + np.testing.assert_allclose(ref, out, atol=1e-3) + + +def test_anchor_tuning_cpu(): + _test_anchor_tuning("llvm --num-cores=4") 
+ + +def test_anchor_tuning_cpu_link_params(): + data_shape = (128, 128) + weight_shape1 = (128, 128) + weight_shape2 = (128, 128) + + data = relay.var("data", shape=data_shape, dtype="float32") + weight1 = relay.var("weight1", shape=weight_shape1, dtype="float32") + weight2 = relay.var("weight2", shape=weight_shape2, dtype="float32") + dense1 = relay.nn.dense(data, weight1) + dense2 = relay.nn.dense(dense1, weight2) + mod = tvm.IRModule.from_expr(dense2 + relay.const(1.0, dtype="float32")) + + weight1_np = np.random.randn(*weight_shape1).astype("float32") + weight2_np = np.random.randn(*weight_shape2).astype("float32") + + data_np = np.random.randn(*data_shape).astype("float32") + params = {"weight1": weight1_np, "weight2": weight2_np} + + module_equality = "anchor-block" + target = "llvm --num-cores=4" + + executor = relay.backend.Executor("graph", {"link-params": True}) + mod = mod.with_attr("executor", executor) + + with tempfile.TemporaryDirectory() as work_dir: + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + params=params, + work_dir=work_dir, + max_trials_global=4, + strategy="replay-trace", + module_equality=module_equality, + ) + lib = ms.relay_integration.compile_relay(database, mod, target, params) + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + out = runtime.get_output(0).numpy() + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + np.testing.assert_allclose(ref, out, atol=1e-3) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py new file mode 100644 index 000000000000..6ff21c72c9ea --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -0,0 +1,2745 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +import tvm +import tvm.testing +import tvm.meta_schedule as ms +from tvm.script import tir as T +from tvm.tir import Schedule, floormod, floordiv +from tvm.tir.tensor_intrin.cuda import * +from tvm.target import Target +from tvm.target.codegen import llvm_lookup_intrinsic_id + + +# fmt: off +@tvm.script.ir_module +class Dense: + @T.prim_func + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_matmul_NT: T.Buffer[(128, 128), "float32"], + ) -> None: + # function attr dict + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("T_matmul_NT"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(p0[i, k], p1[j, k]) + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders": []}) + with T.init(): + T_matmul_NT[i, j] = T.float32(0) + T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k] + + +@tvm.script.ir_module +class DenseAdd: + @T.prim_func + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_add: T.Buffer[(128, 128), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT = T.alloc_buffer([128, 128], dtype="float32") + compile_engine_const = T.alloc_buffer([], dtype="float32") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("T_matmul_NT"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(p0[i, k], p1[j, k]) + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders": []}) + with T.init(): + T_matmul_NT[i, j] = T.float32(0) + T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k] + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = T.float32(1) + for i0, i1 in T.grid(128, 128): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_matmul_NT[ax0, ax1], compile_engine_const[()]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T_matmul_NT[ax0, ax1] + compile_engine_const[()] + + +@tvm.script.ir_module +class DenseAdd_scheduled_cpu: + @T.prim_func + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_add: T.Buffer[(128, 128), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT_global = T.alloc_buffer([128, 128], dtype="float32") + p1_global = T.alloc_buffer([2, 128, 64], dtype="float32") + for ax0, ax1 in T.grid(128, 128): + with T.block("p1_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(p1[v0, v1]) + T.writes(p1_global[v0 // 64, v1, v0 % 64]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": 1}) + p1_global[v0 // 64, v1, v0 % 64] = p1[v0, v1] + for i0_0_i1_0_fused_fused in T.parallel(4): + for i0_1, i1_1 in T.grid(8, 1): + for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 1, 2): + for i1_3_fused_init in T.vectorized(64): + with T.block("T_matmul_NT_init"): + i = T.axis.spatial( + 128, + i0_0_i1_0_fused_fused // 2 * 64 + + i0_1 * 8 + + i0_2_init * 2 + + i0_3_init, + ) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused_fused % 2 * 64 + + i1_1 * 64 + + i1_2_init * 64 + + i1_3_fused_init, + ) + T.reads() + T.writes(T_matmul_NT_global[i, j]) + T.block_attr( + { + 
"layout_free_placeholders": [], + "meta_schedule.tiling_structure": "SSRSRS", + } + ) + T_matmul_NT_global[i, j] = T.float32(0) + for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(128, 4, 1, 1, 2): + for i1_3_fused in T.vectorized(64): + with T.block("T_matmul_NT_update"): + i = T.axis.spatial( + 128, i0_0_i1_0_fused_fused // 2 * 64 + i0_1 * 8 + i0_2 * 2 + i0_3 + ) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused_fused % 2 * 64 + i1_1 * 64 + i1_2 * 64 + i1_3_fused, + ) + k = T.axis.reduce(128, i2_0 + i2_1) + T.reads( + T_matmul_NT_global[i, j], p0[i, k], p1_global[j // 64, k, j % 64] + ) + T.writes(T_matmul_NT_global[i, j]) + T.block_attr( + { + "layout_free_placeholders": [], + "meta_schedule.tiling_structure": "SSRSRS", + } + ) + T_matmul_NT_global[i, j] = ( + T_matmul_NT_global[i, j] + p0[i, k] * p1_global[j // 64, k, j % 64] + ) + for ax0 in T.serial(64): + for ax1_fused in T.vectorized(64): + with T.block("T_matmul_NT_global"): + v0 = T.axis.spatial(128, i0_0_i1_0_fused_fused // 2 * 64 + ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_fused_fused % 2 * 64 + ax1_fused) + T.reads(T_matmul_NT_global[v0, v1]) + T.writes(T_add[v0, v1]) + T_add[v0, v1] = T_matmul_NT_global[v0, v1] + T.float32(1) + + +@tvm.script.ir_module +class DenseAdd_cpu_no_write_cache: + @T.prim_func + def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT = T.alloc_buffer([128, 128], dtype="float32") + p1_global = T.alloc_buffer([8, 4, 16, 32], dtype="float32") + for ax0, ax1 in T.grid(128, 128): + with T.block("p1_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(p1[v0, v1]) + T.writes(p1_global[v1 // 16, v0 // 32, v1 % 16, v0 % 32]) + T.block_attr({"meta_schedule.layout_rewrite_preproc":1}) + p1_global[v1 // 16, v0 // 32, v1 % 16, v0 % 32] = p1[v0, v1] + for i0_0_i1_0_i0_1_i1_1_fused in T.parallel(16, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 4, 2): + for i1_3_fused_init in T.vectorized(32): + with T.block("T_matmul_NT_init"): + i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused // 4 * 32 + i0_0_i1_0_i0_1_i1_1_fused % 4 * 8 + i0_2_init * 2 + i0_3_init) + j = T.axis.spatial(128, i1_2_init * 32 + i1_3_fused_init) + T.reads() + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T_matmul_NT[i, j] = T.float32(0) + for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(8, 4, 4, 16, 2): + for i1_3_fused in T.vectorized(32): + with T.block("T_matmul_NT_update"): + i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused // 4 * 32 + i0_0_i1_0_i0_1_i1_1_fused % 4 * 8 + i0_2 * 2 + i0_3) + j = T.axis.spatial(128, i1_2 * 32 + i1_3_fused) + k = T.axis.reduce(128, i2_0 * 16 + i2_1) + T.reads(T_matmul_NT[i, j], p0[i, k], p1_global[k // 16, j // 32, k % 16, j % 32]) + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1_global[k // 16, j // 32, k % 16, j % 32] + for i0_i1_fused in T.parallel(16384): + with T.block("T_add"): + ax0 = T.axis.spatial(128, i0_i1_fused // 128) + ax1 = T.axis.spatial(128, i0_i1_fused % 128) + T.reads(T_matmul_NT[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T_matmul_NT[ax0, ax1] + T.float32(1) + + 
+@tvm.script.ir_module +class DenseAdd_scheduled_gpu: + @T.prim_func + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_add: T.Buffer[(128, 128), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT_local = T.alloc_buffer([128, 128], dtype="float32", scope="local") + p0_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + p1_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding( + 32, + thread="blockIdx.x", + annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}, + ): + for i0_1_i1_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(1, 4, 1, 1): + with T.block("T_matmul_NT_init"): + i = T.axis.spatial( + 128, + i0_0_i1_0_fused // 4 * 16 + + i0_2_i1_2_fused // 8 + + i0_3_init + + i0_4_init, + ) + j = T.axis.spatial( + 128, + i1_4_init + + i0_0_i1_0_fused % 4 * 32 + + i0_2_i1_2_fused % 8 * 4 + + i1_3_init, + ) + T.reads() + T.writes(T_matmul_NT_local[i, j]) + T.block_attr( + { + "layout_free_placeholders": [], + "meta_schedule.thread_extent_high_inclusive": 256, + "meta_schedule.thread_extent_low_inclusive": 16, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + T_matmul_NT_local[i, j] = T.float32(0) + for i2_0 in T.serial(32): + for ax0_ax1_fused_0 in T.serial(1): + for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(2): + with T.block("p0_shared"): + T.where( + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 2 + + ax0_ax1_fused_2 + < 64 + ) + v0 = T.axis.spatial( + 128, + i0_0_i1_0_fused // 4 * 16 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 2 + + ax0_ax1_fused_2 + ) + // 4, + ) + v1 = T.axis.spatial( + 128, + i2_0 * 4 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 2 + + ax0_ax1_fused_2 + ) + % 4, + ) + T.reads(p0[v0, v1]) + T.writes(p0_shared[v0, v1]) + p0_shared[v0, v1] = p0[v0, v1] + for ax0_ax1_fused_0 in T.serial(1): + for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(4): + with T.block("p1_shared"): + T.where( + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 4 + + ax0_ax1_fused_2 + < 128 + ) + v0 = T.axis.spatial( + 128, + i0_0_i1_0_fused % 4 * 32 + + ( + ax0_ax1_fused_0 * 512 + + ax0_ax1_fused_1 * 4 + + ax0_ax1_fused_2 + ) + // 4, + ) + v1 = T.axis.spatial( + 128, + i2_0 * 4 + + ( + ax0_ax1_fused_0 * 512 + + ax0_ax1_fused_1 * 4 + + ax0_ax1_fused_2 + ) + % 4, + ) + T.reads(p1[v0, v1]) + T.writes(p1_shared[v0, v1]) + p1_shared[v0, v1] = p1[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(1, 1, 4, 4, 1, 1): + with T.block("T_matmul_NT_update"): + i = T.axis.spatial( + 128, + i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + i0_3 + i0_4, + ) + j = T.axis.spatial( + 128, + i1_4 + + i0_0_i1_0_fused % 4 * 32 + + i0_2_i1_2_fused % 8 * 4 + + i1_3, + ) + k = T.axis.reduce(128, i2_0 * 4 + i2_1 * 4 + i2_2) + T.reads(T_matmul_NT_local[i, j], p0_shared[i, k], p1_shared[j, k]) + T.writes(T_matmul_NT_local[i, j]) + T.block_attr( + { + "layout_free_placeholders": [], + "meta_schedule.thread_extent_high_inclusive": 256, + "meta_schedule.thread_extent_low_inclusive": 16, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + T_matmul_NT_local[i, j] = ( + 
T_matmul_NT_local[i, j] + p0_shared[i, k] * p1_shared[j, k] + ) + for ax0, ax1 in T.grid(1, 4): + with T.block("T_matmul_NT_local"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + ax0 + ) + v1 = T.axis.spatial( + 128, i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + ax1 + ) + T.reads(T_matmul_NT_local[v0, v1]) + T.writes(T_add[v0, v1]) + T_add[v0, v1] = T_matmul_NT_local[v0, v1] + T.float32(1) + + +@tvm.script.ir_module +class Conv2dInt8: + @T.prim_func + def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], compute: T.Buffer[(16, 56, 56, 256), "int32"]) -> None: + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") + conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_multiply = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_add_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_right_shift = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_cast_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast_2 = T.alloc_buffer([16, 56, 56, 256], dtype="uint8") + T_cast_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = 0 + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_cast[ax0, ax1, ax2, ax3]) + T_cast[ax0, ax1, ax2, 
ax3] = T.cast(T_add[ax0, ax1, ax2, ax3], "int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_cast[ax0, ax1, ax2, ax3], p4[0, 0, 0, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_cast[ax0, ax1, ax2, ax3] * p4[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply[ax0, ax1, ax2, ax3], p5[0, 0, 0, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_multiply[ax0, ax1, ax2, ax3] + p5[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_right_shift"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3]) + T.writes(T_right_shift[ax0, ax1, ax2, ax3]) + T_right_shift[ax0, ax1, ax2, ax3] = T.shift_right(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3], dtype="int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_right_shift[ax0, ax1, ax2, ax3]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3]) + T_cast_1[ax0, ax1, ax2, ax3] = T.cast(T_right_shift[ax0, ax1, ax2, ax3], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p7[()], T_cast_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_2[ax0, ax1, ax2, ax3]) + T_add_2[ax0, ax1, ax2, ax3] = p7[()] + T_cast_1[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("compute"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_2[i0_2, i1_2, i2_2, i3_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) + compute_1[i0_2, i1_2, i2_2, i3_2] = T.max(T.min(T_add_2[i0_2, i1_2, i2_2, i3_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) + T.reads(compute_1[ax0, ax1, ax2, ax3]) + T.writes(T_cast_2[ax0, ax1, ax2, ax3]) + T_cast_2[ax0, ax1, ax2, ax3] = T.cast(compute_1[ax0, ax1, ax2, ax3], "uint8") + for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) + T.reads(T_cast_2[ax0, ax1, ax2, ax3]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3]) + T_cast_3[ax0, ax1, ax2, ax3] = T.cast(T_cast_2[ax0, ax1, ax2, ax3], "int32") + for i0_5, i1_5, i2_5, i3_5 in T.grid(16, 56, 56, 256): + with T.block("T_subtract_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) + T.reads(T_cast_3[ax0, ax1, ax2, ax3], p8[0]) + T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) + T_subtract_1[ax0, ax1, ax2, ax3] = T_cast_3[ax0, ax1, ax2, ax3] - p8[0] + for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): + with T.block("compute_1"): + i0_7, i1_7, i2_7, i3_7 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) + T.reads(T_subtract_1[i0_7, i1_7, i2_7, i3_7]) + T.writes(compute[i0_7, i1_7, i2_7, i3_7]) + compute[i0_7, i1_7, i2_7, i3_7] = T.q_multiply_shift(T_subtract_1[i0_7, i1_7, i2_7, i3_7], 1963325822, 31, 1, dtype="int32") + + +@tvm.script.ir_module +class Conv2dInt8_target: + @T.prim_func + def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], 
p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "uint8"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") + conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_multiply = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_add_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_right_shift = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_cast_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast_2 = T.alloc_buffer([16, 56, 56, 256], dtype="uint8") + T_cast_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast_4 = T.alloc_buffer([16, 56, 56, 256], dtype="uint8") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = 0 + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_cast[ax0, ax1, ax2, ax3]) + T_cast[ax0, ax1, ax2, ax3] = T.cast(T_add[ax0, ax1, ax2, ax3], "int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_cast[ax0, ax1, ax2, ax3], p4[0, 0, 0, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_cast[ax0, ax1, ax2, ax3] * p4[0, 0, 0, ax3] + 
for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply[ax0, ax1, ax2, ax3], p5[0, 0, 0, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_multiply[ax0, ax1, ax2, ax3] + p5[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_right_shift"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3]) + T.writes(T_right_shift[ax0, ax1, ax2, ax3]) + T_right_shift[ax0, ax1, ax2, ax3] = T.shift_right(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3], dtype="int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_right_shift[ax0, ax1, ax2, ax3]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3]) + T_cast_1[ax0, ax1, ax2, ax3] = T.cast(T_right_shift[ax0, ax1, ax2, ax3], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p7[()], T_cast_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_2[ax0, ax1, ax2, ax3]) + T_add_2[ax0, ax1, ax2, ax3] = p7[()] + T_cast_1[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("compute"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_2[i0_2, i1_2, i2_2, i3_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) + compute_1[i0_2, i1_2, i2_2, i3_2] = T.max(T.min(T_add_2[i0_2, i1_2, i2_2, i3_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) + T.reads(compute_1[ax0, ax1, ax2, ax3]) + T.writes(T_cast_2[ax0, ax1, ax2, ax3]) + T_cast_2[ax0, ax1, ax2, ax3] = T.cast(compute_1[ax0, ax1, ax2, ax3], "uint8") + for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) + T.reads(T_cast_2[ax0, ax1, ax2, ax3]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3]) + T_cast_3[ax0, ax1, ax2, ax3] = T.cast(T_cast_2[ax0, ax1, ax2, ax3], "int32") + for i0_5, i1_5, i2_5, i3_5 in T.grid(16, 56, 56, 256): + with T.block("T_subtract_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) + T.reads(T_cast_3[ax0, ax1, ax2, ax3], p8[0]) + T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) + T_subtract_1[ax0, ax1, ax2, ax3] = T_cast_3[ax0, ax1, ax2, ax3] - p8[0] + for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): + with T.block("compute_1"): + i0_7, i1_7, i2_7, i3_7 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) + T.reads(T_subtract_1[i0_7, i1_7, i2_7, i3_7]) + T.writes(compute_2[i0_7, i1_7, i2_7, i3_7]) + compute_2[i0_7, i1_7, i2_7, i3_7] = T.q_multiply_shift(T_subtract_1[i0_7, i1_7, i2_7, i3_7], 1098990753, 31, 1, dtype="int32") + for i0_8, i1_8, i2_8, i3_8 in T.grid(16, 56, 56, 256): + with T.block("T_add_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_8, i1_8, i2_8, i3_8]) + T.reads(compute_2[ax0, ax1, ax2, ax3], p9[ax0, ax1, ax2, ax3]) + T.writes(T_add_3[ax0, ax1, ax2, ax3]) + T_add_3[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] + p9[ax0, ax1, ax2, ax3] + for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 56, 56, 256): + with T.block("compute_2"): + i0_10, i1_10, i2_10, i3_10 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) + T.reads(T_add_3[i0_10, i1_10, i2_10, i3_10]) + T.writes(compute_3[i0_10, i1_10, i2_10, i3_10]) 
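+ # clip the residual-added value to [0, 255] ahead of the uint8 cast below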
+ compute_3[i0_10, i1_10, i2_10, i3_10] = T.max(T.min(T_add_3[i0_10, i1_10, i2_10, i3_10], 255), 0) + for i0_11, i1_11, i2_11, i3_11 in T.grid(16, 56, 56, 256): + with T.block("T_cast_4"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_11, i1_11, i2_11, i3_11]) + T.reads(compute_3[ax0, ax1, ax2, ax3]) + T.writes(T_cast_4[ax0, ax1, ax2, ax3]) + T_cast_4[ax0, ax1, ax2, ax3] = T.cast(compute_3[ax0, ax1, ax2, ax3], "uint8") + for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256): + with T.block("compute_3"): + i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12]) + T.reads(T_cast_4[i0_13, i1_13, i2_13, i3_13]) + T.writes(compute[i0_13, i1_13, i2_13, i3_13]) + compute[i0_13, i1_13, i2_13, i3_13] = T.max(T.min(T_cast_4[i0_13, i1_13, i2_13, i3_13], T.uint8(255)), T.uint8(0)) + + +@tvm.script.ir_module +class Conv2dInt8_tensorcore_scheduled: + @T.prim_func + def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "uint8"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + a0 = T.var("int32") + a1 = T.var("int32") + b0 = T.var("int32") + b1 = T.var("int32") + c0 = T.var("int32") + c1 = T.var("int32") + d0 = T.var("int32") + d0_1 = T.var("int32") + d0_2 = T.var("int32") + d0_3 = T.var("int32") + d1 = T.var("int32") + d1_1 = T.var("int32") + d1_2 = T.var("int32") + d1_3 = T.var("int32") + s0 = T.var("int32") + s0_1 = T.var("int32") + s0_2 = T.var("int32") + s1 = T.var("int32") + s1_1 = T.var("int32") + s1_2 = T.var("int32") + # body + # with T.block("root") + conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator") + pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared") + p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared") + pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", scope="wmma.matrix_a") + p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b") + for ax2_0_0_ax3_0_0_fused in T.thread_binding(3136, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":512, "pragma_unroll_explicit":1}): + for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="vthread.x"): + for ax2_0_2_ax3_0_2_fused in T.thread_binding(16, thread="threadIdx.x"): + for ax0_0, ax1_0 in T.grid(1, 1): + for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1): + with T.block("conv2d_nhwc_o_init"): + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init) + v3_o = T.axis.spatial(16, ax3_0_4_init + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 
16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[d1, d0], scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // d1 // 16 * (d1 // 16) + C.elem_offset % d1 // 16, T.float32(0), dtype="handle")) + for ax4_0_0 in T.serial(2): + for ax0_ax1_fused_0 in T.serial(16): + for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(16): + with T.block("pad_temp_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32) + T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(pad_temp_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]]}) + pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(8): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8): + with T.block("p1_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) // 32) + v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32) + T.reads(p1[v2, v0, v1, v3]) + T.writes(p1_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]]}) + p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] + for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): + for ax0_0_1, ax1_0_1 in T.grid(1, 2): + with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1) + T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[s1, s0], scope="shared", offset_factor=16) + C_1 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_1, d0_1], scope="wmma.matrix_a", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // d1_1 // 16 * (d1_1 // 16) + C_1.elem_offset % d1_1 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A.data, A.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2): + with T.block("p1_reindex_shared_wmma.matrix_b_o"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2) + v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0) + T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + A_1 = T.match_buffer(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[s1_1, s0_1], scope="shared", offset_factor=16) + C_2 = 
T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_2, d0_2], scope="wmma.matrix_b", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // d1_2 // 16 * (d1_2 // 16) + C_2.elem_offset % d1_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, s1_1 * 16, 1, dtype="handle"), s1_1, "col_major", dtype="handle")) + for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1): + with T.block("conv2d_nhwc_o_update"): + v0 = T.axis.reduce(1, 0) + v1 = T.axis.reduce(1, 0) + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(16, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3) + v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + A_2 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[a1, a0], scope="wmma.matrix_a", offset_factor=16) + B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[b1, b0], scope="wmma.matrix_b", offset_factor=16) + C_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[c1, c0], scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_mma_sync(C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, A_2.data, A_2.elem_offset // a1 // 16 * (a1 // 16) + A_2.elem_offset % a1 // 16, B.data, B.elem_offset // b1 // 16 * (b1 // 16) + B.elem_offset % b1 // 16, C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, dtype="handle")) + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + A_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[d1_3, d0_3], scope="wmma.accumulator", offset_factor=16) + C_4 = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[s1_2, s0_2], scope="shared", offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // d1_3 // 16 * (d1_3 // 16) + A_3.elem_offset % d1_3 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), 
C_4.data, C_4.elem_offset, s1_2 * 16, 2, dtype="handle"), s1_2, "row_major", dtype="handle")) + for ax0, ax1_0 in T.grid(128, 2): + for ax1_1 in T.thread_binding(16, thread="threadIdx.x"): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + ax0) + v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + ax1_0 * 16 + ax1_1) + T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[0, 0, 0, v1], p5[0, 0, 0, v1], p6[0, 0, 0, v1], p8[0], p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(p7[()] + T.cast(T.shift_right(T.cast(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], "int64") * p4[0, 0, 0, v1] + p5[0, 0, 0, v1], p6[0, 0, 0, v1], dtype="int64"), "int32"), 255), 0), "uint8"), "int32") - p8[0], 1098990753, 31, 1, dtype="int32") + p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 255), 0), "uint8"), T.uint8(255)), T.uint8(0)) + + +@tvm.script.ir_module +class Conv2dInt8_NCHWc: + @T.prim_func + def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, 4, 16, 4), "int8"], p2: T.Buffer[(1, 128, 1, 1, 16), "int32"], p3: T.Buffer[(1, 128, 1, 1, 16), "float32"], p4: T.Buffer[1, "float32"], p5: T.Buffer[(1, 128, 7, 7, 16), "int32"], compute: T.Buffer[(1, 128, 7, 7, 16), "uint8"]) -> None: + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + compile_engine_const = T.alloc_buffer([], dtype="float32") + conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_1 = T.alloc_buffer([], dtype="float32") + T_add_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + T_cast_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_subtract = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_2 = T.alloc_buffer([], dtype="float32") + T_add_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_4 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_5 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = T.float32(0.94537687301635742) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 128, 7, 7, 16, 1, 1, 32, 4, 4): + with T.block("conv2d_NCHWc_int8"): + n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 
oc_block, ic_s_inner]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], [1, 1], [0, 0, 0, 0], [1, 1], "NCHW16c", "NCHW16c", "int32"]}) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32") * T.cast(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4], p2[ax0, ax1, 0, 0, ax4]) + T.writes(T_add[ax0, ax1, ax2, ax3, ax4]) + T_add[ax0, ax1, ax2, ax3, ax4] = conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4] + p2[ax0, ax1, 0, 0, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast[ax0, ax1, ax2, ax3, ax4]) + T_cast[ax0, ax1, ax2, ax3, ax4] = T.cast(T_add[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast[ax0, ax1, ax2, ax3, ax4], p3[ax0, ax1, 0, 0, ax4]) + T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) + T_multiply[ax0, ax1, ax2, ax3, ax4] = T_cast[ax0, ax1, ax2, ax3, ax4] * p3[ax0, ax1, 0, 0, ax4] + with T.block("compile_engine_const_1"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_1[()]) + compile_engine_const_1[()] = T.float32(54.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply[ax0, ax1, ax2, ax3, ax4], compile_engine_const_1[()]) + T.writes(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T_add_1[ax0, ax1, ax2, ax3, ax4] = T_multiply[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_1[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor[ax0, ax1, ax2, ax3, ax4]) + T_floor[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3, ax4]) + T_cast_1[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_1[i0_1, i1_1, i2_1, i3_1, i4_1]) + T.writes(compute_1[i0_1, i1_1, i2_1, i3_1, i4_1]) + compute_1[i0_1, i1_1, i2_1, i3_1, i4_1] = T.max(T.min(T_cast_1[i0_1, i1_1, i2_1, i3_1, i4_1], 255), 0) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compute_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_2[ax0, 
ax1, ax2, ax3, ax4]) + T_cast_2[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_1[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3, ax4]) + T_cast_3[ax0, ax1, ax2, ax3, ax4] = T.cast(T_cast_2[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_3[ax0, ax1, ax2, ax3, ax4], p4[0]) + T.writes(T_subtract[ax0, ax1, ax2, ax3, ax4]) + T_subtract[ax0, ax1, ax2, ax3, ax4] = T_cast_3[ax0, ax1, ax2, ax3, ax4] - p4[0] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compile_engine_const[()], T_subtract[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) + T_multiply_1[ax0, ax1, ax2, ax3, ax4] = compile_engine_const[()] * T_subtract[ax0, ax1, ax2, ax3, ax4] + with T.block("compile_engine_const_2"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_2[()]) + compile_engine_const_2[()] = T.float32(0.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply_1[ax0, ax1, ax2, ax3, ax4], compile_engine_const_2[()]) + T.writes(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T_add_2[ax0, ax1, ax2, ax3, ax4] = T_multiply_1[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_2[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T_floor_1[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_4[ax0, ax1, ax2, ax3, ax4]) + T_cast_4[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_1[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_4[ax0, ax1, ax2, ax3, ax4], p5[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_add_3[ax0, ax1, ax2, ax3, ax4]) + T_add_3[ax0, ax1, ax2, ax3, ax4] = T_cast_4[ax0, ax1, ax2, ax3, ax4] + p5[ax0, ax1, ax2, ax3, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2, i4_2 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_3[i0_2, i1_2, i2_2, i3_2, i4_2]) + T.writes(compute_2[i0_2, i1_2, i2_2, i3_2, i4_2]) + compute_2[i0_2, i1_2, i2_2, i3_2, i4_2] = T.max(T.min(T_add_3[i0_2, i1_2, i2_2, i3_2, i4_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_5"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_3, i1_3, i2_3, i3_3, i4_3]) + T.reads(compute_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_5[ax0, ax1, ax2, ax3, ax4]) + T_cast_5[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_2[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(1, 128, 7, 7, 16): 
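+ # clamp the uint8 result once more; effectively a saturating no-op after the earlier int32 clip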
+ with T.block("compute_2"): + i0_5, i1_5, i2_5, i3_5, i4_5 = T.axis.remap("SSSSS", [i0_4, i1_4, i2_4, i3_4, i4_4]) + T.reads(T_cast_5[i0_5, i1_5, i2_5, i3_5, i4_5]) + T.writes(compute[i0_5, i1_5, i2_5, i3_5, i4_5]) + compute[i0_5, i1_5, i2_5, i3_5, i4_5] = T.max(T.min(T_cast_5[i0_5, i1_5, i2_5, i3_5, i4_5], T.uint8(255)), T.uint8(0)) + + +@tvm.script.ir_module +class Conv2dInt8_NCHWc_target: + @T.prim_func + def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, 4, 16, 4), "int8"], p2: T.Buffer[(1, 128, 1, 1, 16), "int32"], p3: T.Buffer[(1, 128, 1, 1, 16), "float32"], p4: T.Buffer[1, "float32"], p5: T.Buffer[(1, 128, 7, 7, 16), "uint8"], T_cast: T.Buffer[(1, 128, 7, 7, 16), "int32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compile_engine_const = T.alloc_buffer([], dtype="float32") + conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_1 = T.alloc_buffer([], dtype="float32") + T_add_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + T_cast_4 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_subtract = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_2 = T.alloc_buffer([], dtype="float32") + T_add_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_5 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compile_engine_const_3 = T.alloc_buffer([], dtype="float32") + T_cast_6 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_4 = T.alloc_buffer([], dtype="float32") + T_add_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_7 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add_4 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + compute_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = T.float32(0.95489668846130371) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 128, 7, 7, 16, 1, 1, 32, 4, 4): + with T.block("conv2d_NCHWc_int8"): + n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], 
[1, 1], [0, 0, 0, 0], [1, 1], "NCHW16c", "NCHW16c", "int32"]}) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32") * T.cast(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4], p2[ax0, ax1, 0, 0, ax4]) + T.writes(T_add[ax0, ax1, ax2, ax3, ax4]) + T_add[ax0, ax1, ax2, ax3, ax4] = conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4] + p2[ax0, ax1, 0, 0, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3, ax4]) + T_cast_1[ax0, ax1, ax2, ax3, ax4] = T.cast(T_add[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_1[ax0, ax1, ax2, ax3, ax4], p3[ax0, ax1, 0, 0, ax4]) + T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) + T_multiply[ax0, ax1, ax2, ax3, ax4] = T_cast_1[ax0, ax1, ax2, ax3, ax4] * p3[ax0, ax1, 0, 0, ax4] + with T.block("compile_engine_const_1"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_1[()]) + compile_engine_const_1[()] = T.float32(65.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply[ax0, ax1, ax2, ax3, ax4], compile_engine_const_1[()]) + T.writes(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T_add_1[ax0, ax1, ax2, ax3, ax4] = T_multiply[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_1[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor[ax0, ax1, ax2, ax3, ax4]) + T_floor[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_2[ax0, ax1, ax2, ax3, ax4]) + T_cast_2[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_2[i0_1, i1_1, i2_1, i3_1, i4_1]) + T.writes(compute[i0_1, i1_1, i2_1, i3_1, i4_1]) + compute[i0_1, i1_1, i2_1, i3_1, i4_1] = T.max(T.min(T_cast_2[i0_1, i1_1, i2_1, i3_1, i4_1], 255), 0) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compute[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3, ax4]) + T_cast_3[ax0, ax1, ax2, ax3, ax4] = T.cast(compute[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + 
T.reads(T_cast_3[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_4[ax0, ax1, ax2, ax3, ax4]) + T_cast_4[ax0, ax1, ax2, ax3, ax4] = T.cast(T_cast_3[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_4[ax0, ax1, ax2, ax3, ax4], p4[0]) + T.writes(T_subtract[ax0, ax1, ax2, ax3, ax4]) + T_subtract[ax0, ax1, ax2, ax3, ax4] = T_cast_4[ax0, ax1, ax2, ax3, ax4] - p4[0] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compile_engine_const[()], T_subtract[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) + T_multiply_1[ax0, ax1, ax2, ax3, ax4] = compile_engine_const[()] * T_subtract[ax0, ax1, ax2, ax3, ax4] + with T.block("compile_engine_const_2"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_2[()]) + compile_engine_const_2[()] = T.float32(0.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply_1[ax0, ax1, ax2, ax3, ax4], compile_engine_const_2[()]) + T.writes(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T_add_2[ax0, ax1, ax2, ax3, ax4] = T_multiply_1[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_2[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T_floor_1[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_5[ax0, ax1, ax2, ax3, ax4]) + T_cast_5[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_1[ax0, ax1, ax2, ax3, ax4], "int32") + with T.block("compile_engine_const_3"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_3[()]) + compile_engine_const_3[()] = T.float32(0.71245479583740234) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_5"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(p5[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_6[ax0, ax1, ax2, ax3, ax4]) + T_cast_6[ax0, ax1, ax2, ax3, ax4] = T.cast(p5[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compile_engine_const_3[()], T_cast_6[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4]) + T_multiply_2[ax0, ax1, ax2, ax3, ax4] = compile_engine_const_3[()] * T_cast_6[ax0, ax1, ax2, ax3, ax4] + with T.block("compile_engine_const_4"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_4[()]) + compile_engine_const_4[()] = T.float32(0.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply_2[ax0, ax1, ax2, ax3, ax4], compile_engine_const_4[()]) + T.writes(T_add_3[ax0, ax1, ax2, ax3, ax4]) + T_add_3[ax0, ax1, ax2, ax3, ax4] = T_multiply_2[ax0, ax1, ax2, ax3, ax4] + 
compile_engine_const_4[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_3[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor_2[ax0, ax1, ax2, ax3, ax4]) + T_floor_2[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_3[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_6"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_7[ax0, ax1, ax2, ax3, ax4]) + T_cast_7[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_2[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_5[ax0, ax1, ax2, ax3, ax4], T_cast_7[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_add_4[ax0, ax1, ax2, ax3, ax4]) + T_add_4[ax0, ax1, ax2, ax3, ax4] = T_cast_5[ax0, ax1, ax2, ax3, ax4] + T_cast_7[ax0, ax1, ax2, ax3, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2, i4_2 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_4[i0_2, i1_2, i2_2, i3_2, i4_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2, i4_2]) + compute_1[i0_2, i1_2, i2_2, i3_2, i4_2] = T.max(T.min(T_add_4[i0_2, i1_2, i2_2, i3_2, i4_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_7"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_3, i1_3, i2_3, i3_3, i4_3]) + T.reads(compute_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_8[ax0, ax1, ax2, ax3, ax4]) + T_cast_8[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_1[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute_2"): + i0_5, i1_5, i2_5, i3_5, i4_5 = T.axis.remap("SSSSS", [i0_4, i1_4, i2_4, i3_4, i4_4]) + T.reads(T_cast_8[i0_5, i1_5, i2_5, i3_5, i4_5]) + T.writes(compute_2[i0_5, i1_5, i2_5, i3_5, i4_5]) + compute_2[i0_5, i1_5, i2_5, i3_5, i4_5] = T.max(T.min(T_cast_8[i0_5, i1_5, i2_5, i3_5, i4_5], T.uint8(255)), T.uint8(0)) + for i0_6, i1_6, i2_6, i3_6, i4_6 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_8"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_6, i1_6, i2_6, i3_6, i4_6]) + T.reads(compute_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast[ax0, ax1, ax2, ax3, ax4]) + T_cast[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_2[ax0, ax1, ax2, ax3, ax4], "int32") + + +def get_conv2d_vnni_mod(intrin_id): + @tvm.script.ir_module + class Conv2dInt8_NCHWc_scheduled: + @T.prim_func + def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, 4, 16, 4), "int8"], p2: T.Buffer[(1, 128, 1, 1, 16), "int32"], p3: T.Buffer[(1, 128, 1, 1, 16), "float32"], p4: T.Buffer[1, "float32"], p5: T.Buffer[(1, 128, 7, 7, 16), "uint8"], T_cast: T.Buffer[(1, 128, 7, 7, 16), "int32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + for i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused in T.parallel(128, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): + for i2_1, i3_1, i4_0_1 in T.grid(7, 1, 1): + for i5_0, i6_0 in T.grid(1, 1): + for i1_2_init, i2_2_init, i3_2_init, i1_3_init, i2_3_init, i3_3_init in T.grid(1, 1, 1, 1, 1, 7): + with 
T.block("conv2d_NCHWc_int8_o_init"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(128, i1_2_init + i1_3_init + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init) + ow = T.axis.spatial(7, i3_1 * 7 + i3_2_init * 7 + i3_3_init) + oc_block_o = T.axis.spatial(1, 0) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) + for i4_1 in T.vectorized(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_i_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 + for i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1): + with T.block("conv2d_NCHWc_int8_o_update"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(128, i1_2 + i1_3 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3) + ow = T.axis.spatial(7, i3_1 * 7 + i3_2 * 7 + i3_3) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer = T.axis.reduce(32, i7_0 * 8 + i7_1) + ic_f_inner = T.axis.reduce(4, i8_1 + i8_0) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) + A = T.match_buffer(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], [4], dtype="uint8", offset_factor=1) + B = T.match_buffer(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4], [16, 4], dtype="int8", offset_factor=1) + C = T.match_buffer(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], [16], dtype="int32", offset_factor=1) + A_u8x4: T.uint8x4 = A[0:4] + A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") + B_i8x64: T.int8x64 = B[0, 0:64] + B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16") + C[0:16] = C[0:16] + T.call_llvm_pure_intrin(intrin_id, T.uint32(0), T.broadcast(0, 16), T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7): + for ax4_fused in T.vectorized(16): + with T.block("T_cast_8"): + ax0_1 = T.axis.spatial(1, ax0) + ax1_1 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32 + ax1) + ax2_1 = T.axis.spatial(7, i2_1 + ax2) + ax3_1, ax4 = T.axis.remap("SS", [ax3, ax4_fused]) + T.reads(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4], p2[ax0_1, ax1_1, 0, 0, ax4], p3[ax0_1, ax1_1, 0, 0, ax4], p4[0], p5[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) + T.writes(T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) + T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4] = T.cast(T.max(T.min(T.cast(T.max(T.min(T.cast(T.floor(T.float32(0.95489668846130371) * (T.cast(T.cast(T.max(T.min(T.cast(T.floor(T.cast(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4] + p2[ax0_1, ax1_1, 0, 0, ax4], "float32") * p3[ax0_1, ax1_1, 0, 0, ax4] + T.float32(65.5), dtype="float32"), "int32"), 255), 0), "uint8"), "float32") - p4[0]) + T.float32(0.5), dtype="float32"), "int32") + T.cast(T.floor(T.float32(0.71245479583740234) * T.cast(p5[ax0_1, ax1_1, ax2_1, ax3_1, ax4], "float32") + T.float32(0.5), dtype="float32"), "int32"), 
255), 0), "uint8"), T.uint8(255)), T.uint8(0)), "int32") + + return Conv2dInt8_NCHWc_scheduled + + +@tvm.script.ir_module +class Conv2dWinogradAddRelu: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(6, 6, 64, 64), "float32"], p2: T.Buffer[(1, 1, 1, 64), "float32"], T_relu: T.Buffer[(1, 56, 56, 64), "float32"]) -> None: + # function attr dict + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + input_tile = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + B = T.alloc_buffer([6, 6], dtype="float32") + data_pack = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + A = T.alloc_buffer([6, 4], dtype="float32") + inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32") + conv2d_winograd = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + T_add = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) + T.block_attr({"schedule_rule":"None"}) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") + for i0, i1, i2, i3 in T.grid(6, 6, 196, 64): + with T.block("input_tile"): + eps, nu, p, ci = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci]) + T.writes(input_tile[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile[eps, nu, p, ci] = data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci] + for i0, i1 in T.grid(6, 6): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(B[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, 
T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 196, 64, 6, 6): + with T.block("data_pack"): + eps, nu, p, ci, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(input_tile[r_a, r_b, p, ci], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + with T.init(): + data_pack[eps, nu, p, ci] = T.float32(0) + data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu] + for i0, i1, i2, i3, i4 in T.grid(6, 6, 196, 64, 64): + with T.block("bgemm"): + eps, nu, p, co, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) + T.reads(data_pack[eps, nu, p, ci], p1[eps, nu, co, ci]) + T.writes(bgemm[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[]}) + with T.init(): + bgemm[eps, nu, p, co] = T.float32(0) + bgemm[eps, nu, p, co] = bgemm[eps, nu, p, co] + data_pack[eps, nu, p, ci] * p1[eps, nu, co, ci] + for i0, i1 in T.grid(6, 4): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(A[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 196, 64, 6, 6): + with T.block("inverse"): + vh, vw, p, co, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(bgemm[r_a, r_b, p, co], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + with T.init(): + inverse[vh, vw, p, co] = T.float32(0) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, 
r_b, p, co] * A[r_a, vh] * A[r_b, vw] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("conv2d_winograd"): + n, h, w, co = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co]) + T.writes(conv2d_winograd[n, h, w, co]) + conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[ax0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_relu"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_relu[ax0, ax1, ax2, ax3]) + T_relu[ax0, ax1, ax2, ax3] = T.max(T_add[ax0, ax1, ax2, ax3], T.float32(0)) + + +@tvm.script.ir_module +class Conv2dWinogradAddResidualRelu: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(6, 6, 64, 64), "float32"], p2: T.Buffer[(1, 1, 1, 64), "float32"], p3: T.Buffer[(1, 56, 56, 64), "float32"], T_relu: T.Buffer[(1, 56, 56, 64), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + input_tile = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + B = T.alloc_buffer([6, 6], dtype="float32") + data_pack = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + A = T.alloc_buffer([6, 4], dtype="float32") + inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32") + conv2d_winograd = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + T_add = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + T_add_1 = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) + T.block_attr({"schedule_rule":"None"}) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") + for i0, i1, i2, i3 in T.grid(6, 6, 196, 64): + with T.block("input_tile"): + eps, nu, p, ci = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci]) + T.writes(input_tile[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile[eps, nu, p, ci] = data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci] + for i0, i1 in T.grid(6, 6): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(B[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 
6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 196, 64, 6, 6): + with T.block("data_pack"): + eps, nu, p, ci, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(input_tile[r_a, r_b, p, ci], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + with T.init(): + data_pack[eps, nu, p, ci] = T.float32(0) + data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu] + for i0, i1, i2, i3, i4 in T.grid(6, 6, 196, 64, 64): + with T.block("bgemm"): + eps, nu, p, co, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) + T.reads(data_pack[eps, nu, p, ci], p1[eps, nu, co, ci]) + T.writes(bgemm[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[]}) + with T.init(): + bgemm[eps, nu, p, co] = T.float32(0) + bgemm[eps, nu, p, co] = bgemm[eps, nu, p, co] + data_pack[eps, nu, p, ci] * p1[eps, nu, co, ci] + for i0, i1 in T.grid(6, 4): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(A[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 
and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 196, 64, 6, 6): + with T.block("inverse"): + vh, vw, p, co, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(bgemm[r_a, r_b, p, co], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + with T.init(): + inverse[vh, vw, p, co] = T.float32(0) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("conv2d_winograd"): + n, h, w, co = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co]) + T.writes(conv2d_winograd[n, h, w, co]) + conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[ax0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3], p3[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_add[ax0, ax1, ax2, ax3] + p3[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_relu"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[ax0, ax1, ax2, ax3]) + T.writes(T_relu[ax0, ax1, ax2, ax3]) + T_relu[ax0, ax1, ax2, ax3] = T.max(T_add_1[ax0, ax1, ax2, ax3], T.float32(0)) + + +@tvm.script.ir_module +class Conv2dWinogradAddResidualRelu_scheduled: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(6, 6, 64, 64), "float32"], p2: T.Buffer[(1, 1, 1, 64), "float32"], p3: T.Buffer[(1, 56, 56, 64), "float32"], T_relu: T.Buffer[(1, 56, 56, 64), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + input_tile_local = T.alloc_buffer([6, 6, 196, 64], dtype="float32", scope="local") + data_pack = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32") + bgemm_local = T.alloc_buffer([6, 6, 196, 64], dtype="float32", scope="local") + data_pack_shared = T.alloc_buffer([6, 6, 196, 64], dtype="float32", scope="shared") + p1_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared") + for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(98, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0, ax1, ax2, 
ax3 in T.grid(6, 6, 1, 1): + with T.block("input_tile"): + eps, nu = T.axis.remap("SS", [ax0, ax1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8 + ax2) + ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8 + ax3) + T.reads(p0[p // 196, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1, ci]) + T.writes(input_tile_local[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile_local[eps, nu, p, ci] = T.if_then_else(1 <= p % 196 // 14 * 4 + eps and p % 196 // 14 * 4 + eps < 57 and 1 <= p % 14 * 4 + nu and p % 14 * 4 + nu < 57, p0[p // 196, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1, ci], T.float32(0), dtype="float32") + for i0 in T.unroll(6): + for i1 in T.unroll(6): + with T.block("data_pack_init"): + eps, nu = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8) + ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) + T.reads() + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + data_pack[eps, nu, p, ci] = T.float32(0) + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("data_pack_update"): + eps, nu = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8) + ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(data_pack[eps, nu, p, ci], input_tile_local[r_a, r_b, p, ci]) + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile_local[r_a, r_b, p, ci] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, 
T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(168, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(4, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(48, thread="threadIdx.x"): + for i0_3_init, i1_3_init, i2_3_init, i3_3_init, i0_4_init, i1_4_init, i2_4_init, i3_4_init in T.grid(1, 1, 14, 1, 1, 1, 1, 1): + with T.block("bgemm_init"): + eps = T.axis.spatial(6, i0_4_init + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + 
i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3_init) + nu = T.axis.spatial(6, i1_4_init + i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3_init) + p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3_init + i2_4_init) + co = T.axis.spatial(64, i3_4_init + i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3_init) + T.reads() + T.writes(bgemm_local[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + bgemm_local[eps, nu, p, co] = T.float32(0) + for i4_0 in T.serial(2): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(28): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(48, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): + with T.block("data_pack_shared"): + v0 = T.axis.spatial(6, (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 896) + v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28) + v2 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 896 // 32) + v3 = T.axis.spatial(64, i4_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 32) + T.reads(data_pack[v0, v1, v2, v3]) + T.writes(data_pack_shared[v0, v1, v2, v3]) + data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(48, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): + with T.block("p1_shared"): + v0 = T.axis.spatial(6, (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 512) + v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28) + v2 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 512 // 32) + v3 = T.axis.spatial(64, i4_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 32) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_shared[v0, v1, v2, v3]) + p1_shared[v0, v1, v2, v3] = p1[v0, v1, v2, v3] + for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(2, 1, 1, 14, 1, 16, 1, 1, 1, 1): + with T.block("bgemm_update"): + eps = T.axis.spatial(6, i0_4 + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3) + nu = T.axis.spatial(6, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3) + p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3 + i2_4) + co = T.axis.spatial(64, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3) + ci = T.axis.reduce(64, i4_0 * 32 + i4_1 * 16 + i4_2) + T.reads(bgemm_local[eps, nu, p, co], data_pack_shared[eps, nu, p, ci], p1_shared[eps, nu, co, ci]) + T.writes(bgemm_local[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + bgemm_local[eps, nu, p, co] = bgemm_local[eps, nu, p, co] + data_pack_shared[eps, nu, p, ci] * p1_shared[eps, nu, co, ci] + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1): + with T.block("bgemm_local"): + v0 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused 
// 16 + ax0) + v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28 + ax1) + v2 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + ax2) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + ax3) + T.reads(bgemm_local[v0, v1, v2, v3]) + T.writes(bgemm[v0, v1, v2, v3]) + bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] + for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(25, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(512, thread="threadIdx.x"): + for i0 in T.unroll(4): + for i1 in T.unroll(4): + with T.block("inverse_init"): + T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1 < 12544) + vh, vw = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) // 448 * 7 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 224 // 32) + co = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 448 // 224 * 32 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 32) + T.reads() + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + inverse[vh, vw, p, co] = T.float32(0) + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("inverse_update"): + T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1 < 12544) + vh, vw = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) // 448 * 7 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 224 // 32) + co = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 448 // 224 * 32 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 32) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(inverse[vh, vw, p, co], bgemm[r_a, r_b, p, co]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 
4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0_i1_i2_i3_fused_0 in T.thread_binding(1568, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i0_i1_i2_i3_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("conv2d_winograd"): + n = T.axis.spatial(1, 0) + h = T.axis.spatial(56, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) // 3584) + w = T.axis.spatial(56, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 3584 // 64) + co = T.axis.spatial(64, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 64) + T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co], p2[n, 0, 0, co], p3[n, h, w, co]) + T.writes(T_relu[n, h, w, co]) + T_relu[n, h, w, co] = T.max(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co] + p2[n, 0, 0, co] + p3[n, h, w, co], T.float32(0)) + + +# fmt: on +def verify(anchor_mod, anchor_trace_fun, target_mod, target, ref): + anchor_sch = Schedule(anchor_mod) + anchor_trace_fun(anchor_sch) + anchor_trace = anchor_sch.trace + + sch = Schedule(target_mod) + + ms.trace_apply.schedule_using_anchor_trace(sch, anchor_trace, Target(target)) + + tvm.ir.assert_structural_equal(ref, sch.mod) + + +def test_dense_add_cpu(): + def apply_anchor_trace(sch: Schedule) -> None: + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, n=4, max_innermost_factor=64, decision=[2, 8, 4, 2] + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[2, 1, 1, 64] + ) + l17, l18, l19, l20 = sch.split( + loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True + ) + v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[128, 1]) + l23, l24 = sch.split(loop=l4, factors=[v21, v22], 
preserve_unit_iters=True) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64) + v26 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0 + ) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26) + sch.enter_postproc() + b27 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit") + b28, b29 = sch.get_child_blocks(b27) + l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28) + l40 = sch.fuse(l30, l31, preserve_unit_iters=True) + sch.parallel(loop=l40) + l41 = sch.fuse(l39, preserve_unit_iters=True) + sch.vectorize(loop=l41) + l42, l43, l44 = sch.get_loops(block=b29) + l45 = sch.fuse(l42, preserve_unit_iters=True) + sch.parallel(loop=l45) + l46 = sch.fuse(l44, preserve_unit_iters=True) + sch.vectorize(loop=l46) + b47 = sch.get_block(name="T_matmul_NT", func_name="main") + l48, l49, l50, l51, l52, l53, l54, l55, l56 = sch.get_loops(block=b47) + b57 = sch.decompose_reduction(block=b47, loop=l51) + b58 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b59 = sch.cache_read(block=b58, read_buffer_index=2, storage_scope="global") + sch.transform_layout( + block=b58, + buffer=("read", 2), + index_map=tvm.tir.IndexMap.from_func( + lambda i0, i1: ( + floordiv(i0, 64), + i1, + floormod(i0, 64), + ), + inverse_index_map=lambda i0, i1, i2: ( + ((i0 * 64) + i2), + i1, + ), + ), + pad_value=None, + ) + sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + + verify(Dense, apply_anchor_trace, DenseAdd, "llvm", DenseAdd_scheduled_cpu) + + +def test_dense_add_cpu_no_write_cache(): + def apply_trace(sch): + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, n=4, max_innermost_factor=64, decision=[4, 4, 4, 2] + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[1, 1, 4, 32] + ) + l17, l18, l19, l20 = sch.split( + loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True + ) + v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[8, 16]) + l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64) + v25 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25) + sch.enter_postproc() + b26 = sch.get_block(name="root", func_name="main") + 
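+        # From here the trace replays MetaSchedule's postprocessing: the meta_schedule.parallel / meta_schedule.vectorize / meta_schedule.unroll_explicit annotations recorded above are stripped from the root block and applied as concrete fuse/parallel/vectorize loop transforms.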
sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.unroll_explicit") + (b27,) = sch.get_child_blocks(b26) + l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) + l38 = sch.fuse(l28, l29, l30, l31, preserve_unit_iters=True) + sch.parallel(loop=l38) + l39 = sch.fuse(l37, preserve_unit_iters=True) + sch.vectorize(loop=l39) + sch.annotate(block_or_loop=l38, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l38, ann_key="pragma_unroll_explicit", ann_val=1) + b40 = sch.get_block(name="T_matmul_NT", func_name="main") + l41, l42, l43, l44, l45, l46, l47 = sch.get_loops(block=b40) + b48 = sch.decompose_reduction(block=b40, loop=l42) + b49 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b50 = sch.cache_read(block=b49, read_buffer_index=2, storage_scope="global") + sch.transform_layout( + block=b49, + buffer=("read", 2), + index_map=tvm.tir.IndexMap.from_func( + lambda i0, i1: ( + floordiv(i1, 16), + floordiv(i0, 32), + floormod(i1, 16), + floormod(i0, 32), + ), + inverse_index_map=lambda i0, i1, i2, i3: ( + ((i1 * 32) + i3), + ((i0 * 16) + i2), + ), + ), + pad_value=None, + ) + sch.annotate(block_or_loop=b50, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + + verify(Dense, apply_trace, DenseAdd, "llvm", DenseAdd_cpu_no_write_cache) + + +def test_dense_add_gpu(): + def apply_anchor_trace(sch: Schedule) -> None: + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8, v9 = sch.sample_perfect_tile( + loop=l2, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1] + ) + l10, l11, l12, l13, l14 = sch.split( + loop=l2, factors=[v5, v6, v7, v8, v9], preserve_unit_iters=True + ) + v15, v16, v17, v18, v19 = sch.sample_perfect_tile( + loop=l3, n=5, max_innermost_factor=64, decision=[4, 1, 8, 4, 1] + ) + l20, l21, l22, l23, l24 = sch.split( + loop=l3, factors=[v15, v16, v17, v18, v19], preserve_unit_iters=True + ) + v25, v26, v27 = sch.sample_perfect_tile( + loop=l4, n=3, max_innermost_factor=64, decision=[32, 1, 4] + ) + l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27], preserve_unit_iters=True) + sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) + l31 = sch.fuse(l10, l20, preserve_unit_iters=True) + sch.bind(loop=l31, thread_axis="blockIdx.x") + l32 = sch.fuse(l11, l21, preserve_unit_iters=True) + sch.bind(loop=l32, thread_axis="vthread.x") + l33 = sch.fuse(l12, l22, preserve_unit_iters=True) + sch.bind(loop=l33, thread_axis="threadIdx.x") + sch.annotate( + block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=16 + ) + sch.annotate( + block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256 + ) + b34 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b34, loop=l33, preserve_unit_loops=True, index=-1) + b35 = sch.cache_read( + block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0] + ) + sch.compute_at(block=b35, loop=l28, preserve_unit_loops=True, index=-1) + l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35) + l42 = sch.fuse(l40, l41, preserve_unit_iters=True) + v43 = sch.sample_categorical( + candidates=[1, 2, 3, 
4], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) + b44 = sch.cache_read( + block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0] + ) + sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True, index=-1) + l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44) + l51 = sch.fuse(l49, l50, preserve_unit_iters=True) + v52 = sch.sample_categorical( + candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52) + v53 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], + probs=[ + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + ], + decision=2, + ) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53) + sch.enter_postproc() + sch.unannotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch") + l54, l55, l56, l57, l58 = sch.get_loops(block=b35) + l59, l60, l61 = sch.split(loop=l58, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l61) + sch.bind(loop=l60, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch") + l62, l63, l64, l65, l66 = sch.get_loops(block=b44) + l67, l68, l69 = sch.split(loop=l66, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l69) + sch.bind(loop=l68, thread_axis="threadIdx.x") + b70 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b70, ann_key="meta_schedule.unroll_explicit") + b71, b72, b73, b74 = sch.get_child_blocks(b70) + l75, l76, l77, l78, l79, l80, l81 = sch.get_loops(block=b71) + sch.annotate(block_or_loop=l75, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l75, ann_key="pragma_unroll_explicit", ann_val=1) + l82, l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b72) + sch.annotate(block_or_loop=l82, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l82, ann_key="pragma_unroll_explicit", ann_val=1) + l89, l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b73) + sch.annotate(block_or_loop=l89, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l89, ann_key="pragma_unroll_explicit", ann_val=1) + l99, l100, l101, l102, l103 = sch.get_loops(block=b74) + sch.annotate(block_or_loop=l99, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l99, ann_key="pragma_unroll_explicit", ann_val=1) + b104 = sch.get_block(name="T_matmul_NT", func_name="main") + l105, l106, l107, l108, l109, l110, l111, l112, l113, l114 = sch.get_loops(block=b104) + b115 = sch.decompose_reduction(block=b104, loop=l108) + + verify(Dense, apply_anchor_trace, DenseAdd, "cuda", DenseAdd_scheduled_gpu) + + +def test_conv2d_int8_tensorcore(): + def apply_trace(sch): + b0 = sch.get_block(name="pad_temp", func_name="main") + b1 = sch.get_block(name="conv2d_nhwc", func_name="main") + b2 = sch.get_block(name="T_subtract", func_name="main") + b3 = sch.get_block(name="T_add", func_name="main") + b4 = sch.get_block(name="T_cast", func_name="main") + b5 = sch.get_block(name="T_multiply", func_name="main") + b6 = sch.get_block(name="T_add_1", func_name="main") + b7 = sch.get_block(name="T_right_shift", func_name="main") + b8 = sch.get_block(name="T_cast_1", func_name="main") + b9 = sch.get_block(name="T_add_2", func_name="main") 
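+        # The remaining block handles (through compute_1) name the requantization epilogue; all of them are inlined into the tensorized conv2d later in this trace.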
+ b10 = sch.get_block(name="compute", func_name="main") + b11 = sch.get_block(name="T_cast_2", func_name="main") + b12 = sch.get_block(name="T_cast_3", func_name="main") + b13 = sch.get_block(name="T_subtract_1", func_name="main") + b14 = sch.get_block(name="compute_1", func_name="main") + b15 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + b16 = sch.reindex(block=b1, buffer=("write", 0)) + b17 = sch.reindex(block=b1, buffer=("read", 0)) + b18 = sch.reindex(block=b1, buffer=("read", 1)) + sch.transform_layout( + block=b1, + buffer=("read", 0), + index_map=lambda nn, yy, xx, rc: ( + (((nn * 3136) + (yy * 56)) + xx), + rc, + ), + pad_value=None, + ) + sch.transform_layout( + block=b1, + buffer=("read", 1), + index_map=lambda ff, ry, rx, rc: ( + ry, + rx, + ff, + rc, + ), + pad_value=None, + ) + sch.transform_layout( + block=b1, + buffer=("write", 0), + index_map=lambda nn, yy, xx, ff: ( + (((nn * 3136) + (yy * 56)) + xx), + ff, + ), + pad_value=None, + ) + sch.transform_block_layout( + block=b16, + index_map=lambda nn, yy, xx, ff: ( + (((nn * 3136) + (yy * 56)) + xx), + ff, + ), + ) + sch.transform_block_layout( + block=b17, + index_map=lambda nn, yy, xx, rc: ( + (((nn * 3136) + (yy * 56)) + xx), + rc, + ), + ) + sch.transform_block_layout( + block=b18, + index_map=lambda ff, ry, rx, rc: ( + ry, + rx, + ff, + rc, + ), + ) + sch.transform_block_layout( + block=b1, + index_map=lambda nn, yy, xx, ff, ry, rx, rc: ( + ry, + rx, + (((nn * 3136) + (yy * 56)) + xx), + ff, + rc, + ), + ) + l19, l20, l21, l22, l23 = sch.get_loops(block=b1) + l24, l25 = sch.split(loop=l23, factors=[None, 16], preserve_unit_iters=True) + l26, l27 = sch.split(loop=l22, factors=[None, 16], preserve_unit_iters=True) + l28, l29 = sch.split(loop=l21, factors=[None, 16], preserve_unit_iters=True) + l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b1) + sch.reorder(l34, l36, l29, l27, l25) + b38 = sch.blockize(loop=l29) + sch.annotate( + block_or_loop=b38, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_sync_16x16x16_s8s8s32_trans", + ) + sch.annotate( + block_or_loop=b38, + ann_key="meta_schedule.auto_tensorize_init", + ann_val="wmma_fill_16x16x16_s32", + ) + sch.annotate(block_or_loop=b38, ann_key="warp_execution", ann_val=1) + l39, l40, l41, l42, l43 = sch.get_loops(block=b38) + v44, v45, v46 = sch.sample_perfect_tile( + loop=l39, n=3, max_innermost_factor=4, decision=[1, 1, 1] + ) + l47, l48, l49 = sch.split(loop=l39, factors=[v44, v45, v46], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile( + loop=l40, n=3, max_innermost_factor=4, decision=[1, 1, 1] + ) + l53, l54, l55 = sch.split(loop=l40, factors=[v50, v51, v52], preserve_unit_iters=True) + v56, v57, v58, v59, v60 = sch.sample_perfect_tile( + loop=l41, n=5, max_innermost_factor=4, decision=[392, 1, 8, 1, 1] + ) + l61, l62, l63, l64, l65 = sch.split( + loop=l41, factors=[v56, v57, v58, v59, v60], preserve_unit_iters=True + ) + v66, v67, v68, v69, v70 = sch.sample_perfect_tile( + loop=l42, n=5, max_innermost_factor=4, decision=[8, 1, 2, 1, 1] + ) + l71, l72, l73, l74, l75 = sch.split( + loop=l42, factors=[v66, v67, v68, v69, v70], preserve_unit_iters=True + ) + v76, v77, v78 = sch.sample_perfect_tile( + loop=l43, n=3, max_innermost_factor=4, decision=[2, 1, 2] + ) + l79, l80, l81 = sch.split(loop=l43, factors=[v76, v77, v78], preserve_unit_iters=True) + sch.reorder( + l61, + l71, + l62, + l72, + l63, + l73, + l47, + l53, + l79, + l48, + 
l54, + l80, + l64, + l74, + l49, + l55, + l81, + l65, + l75, + ) + l82 = sch.fuse(l61, l71, preserve_unit_iters=True) + sch.bind(loop=l82, thread_axis="blockIdx.x") + l83 = sch.fuse(l62, l72, preserve_unit_iters=True) + sch.bind(loop=l83, thread_axis="vthread.x") + l84 = sch.fuse(l63, l73, preserve_unit_iters=True) + sch.bind(loop=l84, thread_axis="threadIdx.x") + sch.annotate( + block_or_loop=b38, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32 + ) + sch.annotate( + block_or_loop=b38, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024 + ) + b85 = sch.cache_write(block=b38, write_buffer_index=0, storage_scope="shared") + sch.reverse_compute_at(block=b85, loop=l83, preserve_unit_loops=True, index=-1) + b86 = sch.cache_write(block=b38, write_buffer_index=0, storage_scope="wmma.accumulator") + sch.reverse_compute_at(block=b86, loop=l84, preserve_unit_loops=True, index=-1) + v87 = sch.sample_categorical( + candidates=[1, 2, 3, 4, 8, 16], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=0, + ) + sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) + sch.reverse_compute_inline(block=b16) + l88, l89, l90, l91, l92 = sch.get_loops(block=b86) + l93, l94 = sch.split(loop=l92, factors=[None, 16], preserve_unit_iters=True) + l95, l96 = sch.split(loop=l91, factors=[None, 16], preserve_unit_iters=True) + l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b86) + sch.reorder(l102, l96, l94) + b104 = sch.blockize(loop=l96) + sch.annotate( + block_or_loop=b104, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_store_16x16x16_s32_shared", + ) + b105 = sch.cache_read( + block=b38, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b38] + ) + sch.compute_at(block=b105, loop=l79, preserve_unit_loops=True, index=-1) + l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b105) + l114 = sch.fuse(l112, l113, preserve_unit_iters=True) + v115 = sch.sample_categorical( + candidates=[1, 2, 3, 4, 8, 16], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=5, + ) + sch.annotate(block_or_loop=b105, ann_key="meta_schedule.cooperative_fetch", ann_val=v115) + b116 = sch.cache_read( + block=b38, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b38] + ) + sch.compute_at(block=b116, loop=l79, preserve_unit_loops=True, index=-1) + l117, l118, l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b116) + l127 = sch.fuse(l123, l124, l125, l126, preserve_unit_iters=True) + v128 = sch.sample_categorical( + candidates=[1, 2, 3, 4, 8, 16], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=4, + ) + sch.annotate(block_or_loop=b116, ann_key="meta_schedule.cooperative_fetch", ann_val=v128) + b129 = sch.cache_read(block=b38, read_buffer_index=0, storage_scope="wmma.matrix_a") + sch.compute_at(block=b129, loop=l80, preserve_unit_loops=True, index=-1) + l130, l131, l132, l133, l134, l135, l136, l137, l138, l139, l140 = sch.get_loops(block=b129) + l141, l142 = sch.split(loop=l140, factors=[None, 16], preserve_unit_iters=True) + l143, l144 = sch.split(loop=l139, factors=[None, 16], preserve_unit_iters=True) + ( + l145, + l146, + l147, + l148, + l149, + l150, + l151, + 
l152, + l153, + l154, + l155, + l156, + l157, + ) = sch.get_loops(block=b129) + sch.reorder(l156, l144, l142) + b158 = sch.blockize(loop=l144) + sch.annotate( + block_or_loop=b158, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_load_16x16x16_s8_a", + ) + b159 = sch.cache_read(block=b38, read_buffer_index=1, storage_scope="wmma.matrix_b") + sch.compute_at(block=b159, loop=l80, preserve_unit_loops=True, index=-1) + ( + l160, + l161, + l162, + l163, + l164, + l165, + l166, + l167, + l168, + l169, + l170, + l171, + l172, + ) = sch.get_loops(block=b159) + l173, l174 = sch.split(loop=l172, factors=[None, 16], preserve_unit_iters=True) + l175, l176 = sch.split(loop=l171, factors=[None, 16], preserve_unit_iters=True) + ( + l177, + l178, + l179, + l180, + l181, + l182, + l183, + l184, + l185, + l186, + l187, + l188, + l189, + l190, + l191, + ) = sch.get_loops(block=b159) + sch.reorder(l190, l176, l174) + b192 = sch.blockize(loop=l176) + sch.annotate( + block_or_loop=b192, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_load_16x16x16_s8_b_trans", + ) + sch.compute_inline(block=b17) + sch.compute_inline(block=b18) + sch.storage_align(block=b105, buffer_index=0, axis=-2, factor=32, offset=16) + sch.storage_align(block=b116, buffer_index=0, axis=-2, factor=32, offset=16) + sch.reverse_compute_inline(block=b14) + sch.reverse_compute_inline(block=b13) + sch.reverse_compute_inline(block=b12) + sch.reverse_compute_inline(block=b11) + sch.reverse_compute_inline(block=b10) + sch.reverse_compute_inline(block=b9) + sch.reverse_compute_inline(block=b8) + sch.reverse_compute_inline(block=b7) + sch.reverse_compute_inline(block=b6) + sch.reverse_compute_inline(block=b5) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.compute_inline(block=b0) + v193 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], + probs=[ + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + ], + decision=3, + ) + sch.annotate(block_or_loop=b15, ann_key="meta_schedule.unroll_explicit", ann_val=v193) + sch.enter_postproc() + sch.unannotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch") + l194, l195, l196, l197 = sch.get_loops(block=b85) + l198, l199 = sch.split(loop=l197, factors=[None, 16], preserve_unit_iters=True) + sch.bind(loop=l199, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b105, ann_key="meta_schedule.cooperative_fetch") + l200, l201, l202, l203, l204, l205, l206 = sch.get_loops(block=b105) + l207, l208, l209 = sch.split(loop=l206, factors=[None, 16, 16], preserve_unit_iters=True) + sch.vectorize(loop=l209) + sch.bind(loop=l208, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b116, ann_key="meta_schedule.cooperative_fetch") + l210, l211, l212, l213, l214, l215, l216 = sch.get_loops(block=b116) + l217, l218, l219 = sch.split(loop=l216, factors=[None, 16, 8], preserve_unit_iters=True) + sch.vectorize(loop=l219) + sch.bind(loop=l218, thread_axis="threadIdx.x") + b220 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b220, ann_key="meta_schedule.unroll_explicit") + b221, b222, b223, b224, b225, b226, b227 = sch.get_child_blocks(b220) + l228, l229, l230, l231, l232, l233, l234, l235, l236 = sch.get_loops(block=b221) + sch.annotate(block_or_loop=l228, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l228, ann_key="pragma_unroll_explicit", ann_val=1) + l237, l238, l239, l240, 
l241, l242, l243, l244, l245 = sch.get_loops(block=b222) + sch.annotate(block_or_loop=l237, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l237, ann_key="pragma_unroll_explicit", ann_val=1) + l246, l247, l248, l249, l250, l251, l252, l253, l254, l255, l256 = sch.get_loops(block=b223) + sch.annotate(block_or_loop=l246, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l246, ann_key="pragma_unroll_explicit", ann_val=1) + ( + l257, + l258, + l259, + l260, + l261, + l262, + l263, + l264, + l265, + l266, + l267, + l268, + l269, + ) = sch.get_loops(block=b224) + sch.annotate(block_or_loop=l257, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l257, ann_key="pragma_unroll_explicit", ann_val=1) + ( + l270, + l271, + l272, + l273, + l274, + l275, + l276, + l277, + l278, + l279, + l280, + l281, + l282, + l283, + l284, + l285, + ) = sch.get_loops(block=b225) + sch.annotate(block_or_loop=l270, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l270, ann_key="pragma_unroll_explicit", ann_val=1) + l286, l287, l288, l289, l290 = sch.get_loops(block=b226) + sch.annotate(block_or_loop=l286, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l286, ann_key="pragma_unroll_explicit", ann_val=1) + l291, l292, l293, l294, l295 = sch.get_loops(block=b227) + sch.annotate(block_or_loop=l291, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l291, ann_key="pragma_unroll_explicit", ann_val=1) + b296 = sch.get_block(name="conv2d_nhwc_o", func_name="main") + ( + l297, + l298, + l299, + l300, + l301, + l302, + l303, + l304, + l305, + l306, + l307, + l308, + l309, + l310, + l311, + l312, + ) = sch.get_loops(block=b296) + b313 = sch.decompose_reduction(block=b296, loop=l302) + sch.unannotate(block_or_loop=b313, ann_key="meta_schedule.auto_tensorize") + sch.annotate( + block_or_loop=b313, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_fill_16x16x16_s32", + ) + sch.unannotate(block_or_loop=b296, ann_key="meta_schedule.auto_tensorize_init") + sch.unannotate(block_or_loop=b313, ann_key="meta_schedule.auto_tensorize_init") + b314 = sch.get_block(name="conv2d_nhwc_o_init", func_name="main") + sch.unannotate(block_or_loop=b314, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b314, tensor_intrin="wmma_fill_16x16x16_s32") + b315 = sch.get_block(name="pad_temp_reindex_shared_wmma.matrix_a_o", func_name="main") + sch.unannotate(block_or_loop=b315, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b315, tensor_intrin="wmma_load_16x16x16_s8_a") + b316 = sch.get_block(name="p1_reindex_shared_wmma.matrix_b_o", func_name="main") + sch.unannotate(block_or_loop=b316, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b316, tensor_intrin="wmma_load_16x16x16_s8_b_trans") + b317 = sch.get_block(name="conv2d_nhwc_o_update", func_name="main") + sch.unannotate(block_or_loop=b317, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b317, tensor_intrin="wmma_sync_16x16x16_s8s8s32_trans") + b318 = sch.get_block(name="conv2d_nhwc_reindex_shared_wmma.accumulator_o", func_name="main") + sch.unannotate(block_or_loop=b318, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b318, tensor_intrin="wmma_store_16x16x16_s32_shared") + + verify(Conv2dInt8, apply_trace, Conv2dInt8_target, "cuda", Conv2dInt8_tensorcore_scheduled) + + +def test_conv2d_int8_vnni(): + def 
apply_trace(sch): + b0 = sch.get_block(name="compile_engine_const", func_name="main") + b1 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") + b2 = sch.get_block(name="T_add", func_name="main") + b3 = sch.get_block(name="T_cast", func_name="main") + b4 = sch.get_block(name="T_multiply", func_name="main") + b5 = sch.get_block(name="compile_engine_const_1", func_name="main") + b6 = sch.get_block(name="T_add_1", func_name="main") + b7 = sch.get_block(name="T_floor", func_name="main") + b8 = sch.get_block(name="T_cast_1", func_name="main") + b9 = sch.get_block(name="compute", func_name="main") + b10 = sch.get_block(name="T_cast_2", func_name="main") + b11 = sch.get_block(name="T_cast_3", func_name="main") + b12 = sch.get_block(name="T_subtract", func_name="main") + b13 = sch.get_block(name="T_multiply_1", func_name="main") + b14 = sch.get_block(name="compile_engine_const_2", func_name="main") + b15 = sch.get_block(name="T_add_2", func_name="main") + b16 = sch.get_block(name="T_floor_1", func_name="main") + b17 = sch.get_block(name="T_cast_4", func_name="main") + b18 = sch.get_block(name="T_add_3", func_name="main") + b19 = sch.get_block(name="compute_1", func_name="main") + b20 = sch.get_block(name="T_cast_5", func_name="main") + b21 = sch.get_block(name="root", func_name="main") + sch.compute_inline(block=b20) + sch.compute_inline(block=b19) + sch.compute_inline(block=b18) + sch.compute_inline(block=b17) + sch.compute_inline(block=b16) + sch.compute_inline(block=b15) + sch.compute_inline(block=b14) + sch.compute_inline(block=b13) + sch.compute_inline(block=b12) + sch.compute_inline(block=b11) + sch.compute_inline(block=b10) + sch.compute_inline(block=b9) + sch.compute_inline(block=b8) + sch.compute_inline(block=b7) + sch.compute_inline(block=b6) + sch.compute_inline(block=b5) + sch.compute_inline(block=b4) + sch.compute_inline(block=b3) + sch.compute_inline(block=b2) + sch.compute_inline(block=b0) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l22, l23, l24, l25, l26, l27, l28, l29, l30, l31 = sch.get_loops(block=b1) + l32, l33 = sch.split(loop=l31, factors=[None, 4], preserve_unit_iters=True) + l34, l35 = sch.split(loop=l26, factors=[None, 16], preserve_unit_iters=True) + l36, l37, l38, l39, l40, l41, l42, l43, l44, l45, l46, l47 = sch.get_loops(block=b1) + sch.reorder(l42, l43, l44, l45, l46, l35, l33) + b48 = sch.blockize(loop=l35) + sch.annotate( + block_or_loop=b48, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni" + ) + l49, l50, l51, l52, l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b48) + v59, v60, v61, v62 = sch.sample_perfect_tile( + loop=l49, n=4, max_innermost_factor=64, decision=[1, 1, 1, 1] + ) + l63, l64, l65, l66 = sch.split( + loop=l49, factors=[v59, v60, v61, v62], preserve_unit_iters=True + ) + v67, v68, v69, v70 = sch.sample_perfect_tile( + loop=l50, n=4, max_innermost_factor=64, decision=[4, 32, 1, 1] + ) + l71, l72, l73, l74 = sch.split( + loop=l50, factors=[v67, v68, v69, v70], preserve_unit_iters=True + ) + v75, v76, v77, v78 = sch.sample_perfect_tile( + loop=l51, n=4, max_innermost_factor=64, decision=[1, 7, 1, 1] + ) + l79, l80, l81, l82 = sch.split( + loop=l51, factors=[v75, v76, v77, v78], preserve_unit_iters=True + ) + v83, v84, v85, v86 = sch.sample_perfect_tile( + loop=l52, n=4, max_innermost_factor=64, decision=[1, 1, 1, 7] + ) + l87, l88, l89, l90 = sch.split( + loop=l52, factors=[v83, v84, v85, v86], preserve_unit_iters=True + ) + v91, v92, v93, v94 = sch.sample_perfect_tile( + 
loop=l53, n=4, max_innermost_factor=64, decision=[1, 1, 1, 1] + ) + l95, l96, l97, l98 = sch.split( + loop=l53, factors=[v91, v92, v93, v94], preserve_unit_iters=True + ) + v99, v100 = sch.sample_perfect_tile(loop=l54, n=2, max_innermost_factor=64, decision=[1, 1]) + l101, l102 = sch.split(loop=l54, factors=[v99, v100], preserve_unit_iters=True) + v103, v104 = sch.sample_perfect_tile( + loop=l55, n=2, max_innermost_factor=64, decision=[1, 1] + ) + l105, l106 = sch.split(loop=l55, factors=[v103, v104], preserve_unit_iters=True) + v107, v108 = sch.sample_perfect_tile( + loop=l56, n=2, max_innermost_factor=64, decision=[4, 8] + ) + l109, l110 = sch.split(loop=l56, factors=[v107, v108], preserve_unit_iters=True) + v111, v112 = sch.sample_perfect_tile( + loop=l57, n=2, max_innermost_factor=64, decision=[4, 1] + ) + l113, l114 = sch.split(loop=l57, factors=[v111, v112], preserve_unit_iters=True) + v115, v116 = sch.sample_perfect_tile( + loop=l58, n=2, max_innermost_factor=64, decision=[1, 1] + ) + l117, l118 = sch.split(loop=l58, factors=[v115, v116], preserve_unit_iters=True) + sch.reorder( + l63, + l71, + l79, + l87, + l95, + l64, + l72, + l80, + l88, + l96, + l101, + l105, + l109, + l113, + l117, + l65, + l73, + l81, + l89, + l97, + l102, + l106, + l110, + l114, + l118, + l66, + l74, + l82, + l90, + l98, + ) + (b119,) = sch.get_consumers(block=b48) + sch.reverse_compute_at(block=b119, loop=l96, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b21, ann_key="meta_schedule.parallel", ann_val=96) + sch.annotate(block_or_loop=b21, ann_key="meta_schedule.vectorize", ann_val=64) + v120 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2 + ) + sch.annotate(block_or_loop=b21, ann_key="meta_schedule.unroll_explicit", ann_val=v120) + sch.enter_postproc() + b121 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.unroll_explicit") + b122, b123 = sch.get_child_blocks(b121) + ( + l124, + l125, + l126, + l127, + l128, + l129, + l130, + l131, + l132, + l133, + l134, + l135, + l136, + l137, + l138, + l139, + l140, + l141, + l142, + l143, + l144, + l145, + l146, + l147, + l148, + l149, + l150, + l151, + l152, + l153, + ) = sch.get_loops(block=b122) + l154 = sch.fuse(l124, l125, l126, l127, l128, l129, l130, preserve_unit_iters=True) + sch.parallel(loop=l154) + sch.annotate(block_or_loop=l154, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l154, ann_key="pragma_unroll_explicit", ann_val=1) + l155, l156, l157, l158, l159, l160, l161, l162, l163 = sch.get_loops(block=b123) + l164 = sch.fuse(l163, preserve_unit_iters=True) + sch.vectorize(loop=l164) + sch.annotate(block_or_loop=l155, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l155, ann_key="pragma_unroll_explicit", ann_val=1) + b165 = sch.get_block(name="conv2d_NCHWc_int8_o", func_name="main") + ( + l166, + l167, + l168, + l169, + l170, + l171, + l172, + l173, + l174, + l175, + l176, + l177, + l178, + l179, + l180, + l181, + l182, + l183, + l184, + l185, + l186, + l187, + l188, + l189, + ) = sch.get_loops(block=b165) + b190 = sch.decompose_reduction(block=b165, loop=l172) + sch.unannotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize") + sch.annotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize", ann_val="") + 
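+        # Only the update block is tensorized with the VNNI dot-product intrinsic below; the decomposed init block is plain zero-initialization, so it is vectorized instead.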
b191 = sch.get_block(name="conv2d_NCHWc_int8_o_init", func_name="main") + sch.unannotate(block_or_loop=b191, ann_key="meta_schedule.auto_tensorize") + (b192,) = sch.get_child_blocks(b191) + (l193,) = sch.get_loops(block=b192) + sch.vectorize(loop=l193) + b194 = sch.get_block(name="conv2d_NCHWc_int8_o_update", func_name="main") + sch.unannotate(block_or_loop=b194, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b194, tensor_intrin="dot_16x4_vnni") + + vnni_id = llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512") + verify( + Conv2dInt8_NCHWc, + apply_trace, + Conv2dInt8_NCHWc_target, + "llvm -mcpu=cascadelake", + get_conv2d_vnni_mod(vnni_id), + ) + + +def test_winograd_gpu(): + def apply_trace(sch): + b0 = sch.get_block(name="B", func_name="main") + b1 = sch.get_block(name="data_pack", func_name="main") + b2 = sch.get_block(name="bgemm", func_name="main") + b3 = sch.get_block(name="A", func_name="main") + b4 = sch.get_block(name="inverse", func_name="main") + b5 = sch.get_block(name="conv2d_winograd", func_name="main") + b6 = sch.get_block(name="T_add", func_name="main") + b7 = sch.get_block(name="T_relu", func_name="main") + b8 = sch.get_block(name="root", func_name="main") + sch.compute_inline(block=b0) + (b9,) = sch.get_producers(block=b1) + (b10,) = sch.get_producers(block=b9) + l11, l12, l13, l14, l15, l16 = sch.get_loops(block=b1) + v17, v18 = sch.sample_perfect_tile( + loop=l13, n=2, max_innermost_factor=64, decision=[14, 14] + ) + l19, l20 = sch.split(loop=l13, factors=[v17, v18], preserve_unit_iters=True) + v21, v22 = sch.sample_perfect_tile(loop=l14, n=2, max_innermost_factor=64, decision=[8, 8]) + l23, l24 = sch.split(loop=l14, factors=[v21, v22], preserve_unit_iters=True) + sch.unroll(loop=l11) + sch.unroll(loop=l12) + sch.unroll(loop=l15) + sch.unroll(loop=l16) + sch.reorder(l19, l23, l20, l24, l11, l12, l15, l16) + sch.compute_at(block=b9, loop=l24, preserve_unit_loops=True, index=-1) + sch.set_scope(block=b9, buffer_index=0, storage_scope="local") + sch.compute_inline(block=b10) + l25, l26, l27, l28, l29, l30, l31, l32 = sch.get_loops(block=b1) + l33 = sch.fuse(l25, l26, l27, l28, preserve_unit_iters=True) + v34 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=2, + ) + l35, l36 = sch.split(loop=l33, factors=[None, v34], preserve_unit_iters=True) + sch.bind(loop=l35, thread_axis="blockIdx.x") + sch.bind(loop=l36, thread_axis="threadIdx.x") + sch.compute_inline(block=b3) + l37, l38, l39, l40, l41, l42 = sch.get_loops(block=b4) + v43, v44 = sch.sample_perfect_tile(loop=l39, n=2, max_innermost_factor=64, decision=[28, 7]) + l45, l46 = sch.split(loop=l39, factors=[v43, v44], preserve_unit_iters=True) + v47, v48 = sch.sample_perfect_tile(loop=l40, n=2, max_innermost_factor=64, decision=[2, 32]) + l49, l50 = sch.split(loop=l40, factors=[v47, v48], preserve_unit_iters=True) + sch.unroll(loop=l37) + sch.unroll(loop=l38) + sch.unroll(loop=l41) + sch.unroll(loop=l42) + sch.reorder(l45, l49, l46, l50, l37, l38, l41, l42) + l51, l52, l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b4) + l59 = sch.fuse(l51, l52, l53, l54, preserve_unit_iters=True) + v60 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + 
+        l61, l62 = sch.split(loop=l59, factors=[None, v60], preserve_unit_iters=True)
+        sch.bind(loop=l61, thread_axis="blockIdx.x")
+        sch.bind(loop=l62, thread_axis="threadIdx.x")
+        sch.annotate(block_or_loop=b2, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+        l63, l64, l65, l66, l67 = sch.get_loops(block=b2)
+        v68, v69, v70, v71, v72 = sch.sample_perfect_tile(
+            loop=l63, n=5, max_innermost_factor=64, decision=[1, 2, 3, 1, 1]
+        )
+        l73, l74, l75, l76, l77 = sch.split(
+            loop=l63, factors=[v68, v69, v70, v71, v72], preserve_unit_iters=True
+        )
+        v78, v79, v80, v81, v82 = sch.sample_perfect_tile(
+            loop=l64, n=5, max_innermost_factor=64, decision=[6, 1, 1, 1, 1]
+        )
+        l83, l84, l85, l86, l87 = sch.split(
+            loop=l64, factors=[v78, v79, v80, v81, v82], preserve_unit_iters=True
+        )
+        v88, v89, v90, v91, v92 = sch.sample_perfect_tile(
+            loop=l65, n=5, max_innermost_factor=64, decision=[7, 2, 1, 14, 1]
+        )
+        l93, l94, l95, l96, l97 = sch.split(
+            loop=l65, factors=[v88, v89, v90, v91, v92], preserve_unit_iters=True
+        )
+        v98, v99, v100, v101, v102 = sch.sample_perfect_tile(
+            loop=l66, n=5, max_innermost_factor=64, decision=[4, 1, 16, 1, 1]
+        )
+        l103, l104, l105, l106, l107 = sch.split(
+            loop=l66, factors=[v98, v99, v100, v101, v102], preserve_unit_iters=True
+        )
+        v108, v109, v110 = sch.sample_perfect_tile(
+            loop=l67, n=3, max_innermost_factor=64, decision=[2, 2, 16]
+        )
+        l111, l112, l113 = sch.split(loop=l67, factors=[v108, v109, v110], preserve_unit_iters=True)
+        sch.reorder(
+            l73,
+            l83,
+            l93,
+            l103,
+            l74,
+            l84,
+            l94,
+            l104,
+            l75,
+            l85,
+            l95,
+            l105,
+            l111,
+            l112,
+            l76,
+            l86,
+            l96,
+            l106,
+            l113,
+            l77,
+            l87,
+            l97,
+            l107,
+        )
+        l114 = sch.fuse(l73, l83, l93, l103, preserve_unit_iters=True)
+        sch.bind(loop=l114, thread_axis="blockIdx.x")
+        l115 = sch.fuse(l74, l84, l94, l104, preserve_unit_iters=True)
+        sch.bind(loop=l115, thread_axis="vthread.x")
+        l116 = sch.fuse(l75, l85, l95, l105, preserve_unit_iters=True)
+        sch.bind(loop=l116, thread_axis="threadIdx.x")
+        sch.annotate(
+            block_or_loop=b2, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32
+        )
+        sch.annotate(
+            block_or_loop=b2, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024
+        )
+        b117 = sch.cache_write(block=b2, write_buffer_index=0, storage_scope="local")
+        sch.reverse_compute_at(block=b117, loop=l116, preserve_unit_loops=True, index=-1)
+        b118 = sch.cache_read(
+            block=b2, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b2]
+        )
+        sch.compute_at(block=b118, loop=l111, preserve_unit_loops=True, index=-1)
+        l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b118)
+        l127 = sch.fuse(l123, l124, l125, l126, preserve_unit_iters=True)
+        v128 = sch.sample_categorical(
+            candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3
+        )
+        sch.annotate(block_or_loop=b118, ann_key="meta_schedule.cooperative_fetch", ann_val=v128)
+        b129 = sch.cache_read(
+            block=b2, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b2]
+        )
+        sch.compute_at(block=b129, loop=l111, preserve_unit_loops=True, index=-1)
+        l130, l131, l132, l133, l134, l135, l136, l137 = sch.get_loops(block=b129)
+        l138 = sch.fuse(l134, l135, l136, l137, preserve_unit_iters=True)
+        v139 = sch.sample_categorical(
+            candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3
+        )
+        sch.annotate(block_or_loop=b129, ann_key="meta_schedule.cooperative_fetch", ann_val=v139)
+        sch.reverse_compute_inline(block=b7)
+        sch.reverse_compute_inline(block=b6)
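+        # With the add/relu epilogue inlined, sample the explicit-unroll setting for the root block.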
+        v140 = sch.sample_categorical(
+            candidates=[0, 16, 64, 512, 1024],
+            probs=[
+                0.20000000000000001,
+                0.20000000000000001,
+                0.20000000000000001,
+                0.20000000000000001,
+                0.20000000000000001,
+            ],
+            decision=4,
+        )
+        sch.annotate(block_or_loop=b8, ann_key="meta_schedule.unroll_explicit", ann_val=v140)
+        l141, l142, l143, l144 = sch.get_loops(block=b5)
+        l145 = sch.fuse(l141, l142, l143, l144, preserve_unit_iters=True)
+        v146 = sch.sample_categorical(
+            candidates=[32, 64, 128, 256, 512, 1024],
+            probs=[
+                0.16666666666666666,
+                0.16666666666666666,
+                0.16666666666666666,
+                0.16666666666666666,
+                0.16666666666666666,
+                0.16666666666666666,
+            ],
+            decision=2,
+        )
+        l147, l148 = sch.split(loop=l145, factors=[None, v146], preserve_unit_iters=True)
+        sch.bind(loop=l147, thread_axis="blockIdx.x")
+        sch.bind(loop=l148, thread_axis="threadIdx.x")
+        sch.enter_postproc()
+        sch.unannotate(block_or_loop=b118, ann_key="meta_schedule.cooperative_fetch")
+        l149, l150, l151, l152, l153 = sch.get_loops(block=b118)
+        l154, l155, l156 = sch.split(loop=l153, factors=[None, 48, 4], preserve_unit_iters=True)
+        sch.vectorize(loop=l156)
+        sch.bind(loop=l155, thread_axis="threadIdx.x")
+        sch.unannotate(block_or_loop=b129, ann_key="meta_schedule.cooperative_fetch")
+        l157, l158, l159, l160, l161 = sch.get_loops(block=b129)
+        l162, l163, l164 = sch.split(loop=l161, factors=[None, 48, 4], preserve_unit_iters=True)
+        sch.vectorize(loop=l164)
+        sch.bind(loop=l163, thread_axis="threadIdx.x")
+        b165 = sch.get_block(name="root", func_name="main")
+        sch.unannotate(block_or_loop=b165, ann_key="meta_schedule.unroll_explicit")
+        b166, b167, b168, b169, b170, b171, b172, b173 = sch.get_child_blocks(b165)
+        l174, l175, l176, l177, l178, l179 = sch.get_loops(block=b166)
+        sch.annotate(block_or_loop=l174, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l174, ann_key="pragma_unroll_explicit", ann_val=1)
+        l180, l181, l182, l183, l184, l185 = sch.get_loops(block=b167)
+        sch.annotate(block_or_loop=l180, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l180, ann_key="pragma_unroll_explicit", ann_val=1)
+        l186, l187, l188, l189, l190, l191, l192 = sch.get_loops(block=b168)
+        sch.annotate(block_or_loop=l186, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l186, ann_key="pragma_unroll_explicit", ann_val=1)
+        l193, l194, l195, l196, l197, l198, l199 = sch.get_loops(block=b169)
+        sch.annotate(block_or_loop=l193, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l193, ann_key="pragma_unroll_explicit", ann_val=1)
+        (
+            l200,
+            l201,
+            l202,
+            l203,
+            l204,
+            l205,
+            l206,
+            l207,
+            l208,
+            l209,
+            l210,
+            l211,
+            l212,
+            l213,
+        ) = sch.get_loops(block=b170)
+        sch.annotate(block_or_loop=l200, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l200, ann_key="pragma_unroll_explicit", ann_val=1)
+        l214, l215, l216, l217, l218, l219, l220 = sch.get_loops(block=b171)
+        sch.annotate(block_or_loop=l214, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l214, ann_key="pragma_unroll_explicit", ann_val=1)
+        l221, l222, l223, l224, l225, l226 = sch.get_loops(block=b172)
+        sch.annotate(block_or_loop=l221, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l221, ann_key="pragma_unroll_explicit", ann_val=1)
+        l227, l228 = sch.get_loops(block=b173)
+        sch.annotate(block_or_loop=l227, ann_key="pragma_auto_unroll_max_step", ann_val=1024)
+        sch.annotate(block_or_loop=l227, ann_key="pragma_unroll_explicit", ann_val=1)
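+        # Decompose the data_pack, bgemm, and inverse reductions so that their init
+        # blocks are emitted separately.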
+        b229 = sch.get_block(name="data_pack", func_name="main")
+        l230, l231, l232, l233, l234, l235 = sch.get_loops(block=b229)
+        b236 = sch.decompose_reduction(block=b229, loop=l234)
+        b237 = sch.get_block(name="bgemm", func_name="main")
+        (
+            l238,
+            l239,
+            l240,
+            l241,
+            l242,
+            l243,
+            l244,
+            l245,
+            l246,
+            l247,
+            l248,
+            l249,
+            l250,
+            l251,
+        ) = sch.get_loops(block=b237)
+        b252 = sch.decompose_reduction(block=b237, loop=l241)
+        b253 = sch.get_block(name="inverse", func_name="main")
+        l254, l255, l256, l257, l258, l259 = sch.get_loops(block=b253)
+        b260 = sch.decompose_reduction(block=b253, loop=l258)
+
+    verify(
+        Conv2dWinogradAddRelu,
+        apply_trace,
+        Conv2dWinogradAddResidualRelu,
+        "cuda",
+        Conv2dWinogradAddResidualRelu_scheduled,
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py
index d0bfc913eca6..1f91dc593143 100644
--- a/tests/python/unittest/test_meta_schedule_vnni_integration.py
+++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py
@@ -26,6 +26,7 @@
 from tvm import relay
 from tvm._ffi import register_func
 from tvm.tir.schedule import BlockRV, Schedule
+from tvm.tir.schedule.analysis import has_block
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN

 logging.basicConfig(
@@ -44,6 +45,7 @@ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool:
     if sch.mod.attrs is not None and "dense" not in sch.mod.attrs["task_name"]:
         return False
     if dense_block is None:
+        assert has_block(sch, "compute")
         dense_block = sch.get_block("compute")
     assert "dense_vnni" in sch.get(dense_block).annotations["schedule_rule"]