diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py index f232566785d9..85deb7253e86 100644 --- a/python/tvm/meta_schedule/mutator/__init__.py +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -21,3 +21,4 @@ """ from .mutator import Mutator, PyMutator from .mutate_compute_location import MutateComputeLocation +from .mutate_unroll import MutateUnroll diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py new file mode 100644 index 000000000000..f81953d008d4 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates auto unroll step""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateUnroll") +class MutateUnroll(Mutator): + """Mutator that mutates auto unroll step""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc new file mode 100644 index 000000000000..e4454184071c --- /dev/null +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -0,0 +1,141 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if an instruction is annotate with + * `meta_schedule_unroll_explicit` or `meta_schedule_unroll_implicit` + * \param inst The instruction to be checked + * \return Whether the instruction is annotated + */ +bool IsAnnotateWithUnroll(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_unroll_explicit || + ann_key == attr::meta_schedule_unroll_implicit; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates auto unroll step */ +class MutateUnrollNode : public MutatorNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief A candidate to be mutated */ +struct MutateUnrollNode::Candidate { + /*! \brief The sampling instruction to be mutated */ + Instruction inst; + /*! \brief The probability */ + std::vector probs; + /*! \brief The decision made */ + int decision; +}; + +/*! + * \brief Find the Sample-Categorical instruction to be mutated that affects the maximal unroll step + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidates The mutation candidate + * \return Whether a decision is found + */ +bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, + MutateUnrollNode::Candidate* candidate) { + using tir::InstructionKind; + using tir::InstructionNode; + static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + std::unordered_map sample_insts; + std::vector ann_insts; + sample_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_sample_categorical)) { + ICHECK_EQ(inst->outputs.size(), 1); + const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + sample_insts[var_rv] = inst.get(); + } else if (IsAnnotateWithUnroll(inst)) { + ann_insts.push_back(inst.get()); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode); + ICHECK(sample_insts.count(var_rv)); + const InstructionNode* sample_inst = sample_insts.at(var_rv); + ICHECK_EQ(sample_inst->attrs.size(), 2); + candidate->inst = GetRef(sample_inst); + candidate->decision = + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = + support::AsVector(Downcast>(sample_inst->attrs[1])); + return true; +} + +Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { + Candidate candidate; + if (!FindUnrollDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + if (candidate.probs.size() == 0) { + return NullOpt; + } + candidate.probs.erase(candidate.probs.begin() + candidate.decision); + int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); + if (result >= candidate.decision) { + result += 1; + } + return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); +} + +Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py new file mode 100644 index 000000000000..3f3fbcafc0db --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateUnroll, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]]) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) + v57 = sch.sample_categorical( + candidates=[0, 16, 64, 512], + probs=[0.25, 0.25, 0.25, 0.25], + decision=0, + ) + sch.annotate(block_or_loop=root, ann_key="meta_schedule.unroll_explicit", ann_val=v57) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateUnroll() + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_unroll_matmul(): + mutator = _make_mutator(target=Target("llvm --num-cores=16")) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + decision = trace.decisions[trace.insts[-2]] + results.add(decision) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {1, 2, 3} + + +if __name__ == """__main__""": + test_mutate_unroll_matmul()