Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/mutator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
"""
from .mutator import Mutator, PyMutator
from .mutate_compute_location import MutateComputeLocation
from .mutate_unroll import MutateUnroll
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/mutator/mutate_unroll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mutator that mutates auto unroll step"""
from tvm._ffi.registry import register_object

from .. import _ffi_api
from .mutator import Mutator


@register_object("meta_schedule.MutateUnroll")
class MutateUnroll(Mutator):
"""Mutator that mutates auto unroll step"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.MutatorMutateUnroll, # type: ignore # pylint: disable=no-member
)
141 changes: 141 additions & 0 deletions src/meta_schedule/mutator/mutate_unroll.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
* \brief Check if an instruction is annotate with
* `meta_schedule_unroll_explicit` or `meta_schedule_unroll_implicit`
* \param inst The instruction to be checked
* \return Whether the instruction is annotated
*/
bool IsAnnotateWithUnroll(const Instruction& inst) {
static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_annotate)) {
return false;
}
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
return ann_key == attr::meta_schedule_unroll_explicit ||
ann_key == attr::meta_schedule_unroll_implicit;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

using tir::Instruction;
using tir::Trace;

/*! \brief Create a Mutator that mutates auto unroll step */
class MutateUnrollNode : public MutatorNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "meta_schedule.MutateUnroll";
TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode);

public:
struct Candidate;
// Inherit from `MutatorNode`
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
};

/*! \brief A candidate to be mutated */
struct MutateUnrollNode::Candidate {
/*! \brief The sampling instruction to be mutated */
Instruction inst;
/*! \brief The probability */
std::vector<double> probs;
/*! \brief The decision made */
int decision;
};

/*!
* \brief Find the Sample-Categorical instruction to be mutated that affects the maximal unroll step
* \param trace The trace to be mutated
* \param rand_state The random state
* \param candidates The mutation candidate
* \return Whether a decision is found
*/
bool FindUnrollDecision(const Trace& trace, TRandState* rand_state,
MutateUnrollNode::Candidate* candidate) {
using tir::InstructionKind;
using tir::InstructionNode;
static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
std::unordered_map<const PrimExprNode*, const InstructionNode*> sample_insts;
std::vector<const InstructionNode*> ann_insts;
sample_insts.reserve(trace->insts.size());
ann_insts.reserve(trace->insts.size());
for (const Instruction& inst : trace->insts) {
if (inst->kind.same_as(inst_sample_categorical)) {
ICHECK_EQ(inst->outputs.size(), 1);
const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode);
sample_insts[var_rv] = inst.get();
} else if (IsAnnotateWithUnroll(inst)) {
ann_insts.push_back(inst.get());
}
}
int n_ann_insts = ann_insts.size();
if (n_ann_insts == 0) {
return false;
}
const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
ICHECK_EQ(ann_inst->inputs.size(), 2);
const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode);
ICHECK(sample_insts.count(var_rv));
const InstructionNode* sample_inst = sample_insts.at(var_rv);
ICHECK_EQ(sample_inst->attrs.size(), 2);
candidate->inst = GetRef<Instruction>(sample_inst);
candidate->decision =
Downcast<Integer>(trace->decisions[GetRef<Instruction>(sample_inst)])->value;
candidate->probs =
support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(sample_inst->attrs[1]));
return true;
}

Optional<Trace> MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) {
Candidate candidate;
if (!FindUnrollDecision(trace, rand_state, &candidate)) {
return NullOpt;
}
if (candidate.probs.size() == 0) {
return NullOpt;
}
candidate.probs.erase(candidate.probs.begin() + candidate.decision);
int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)();
if (result >= candidate.decision) {
result += 1;
}
return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true);
}

Mutator Mutator::MutateUnroll() { return Mutator(make_object<MutateUnrollNode>()); }

TVM_REGISTER_NODE_TYPE(MutateUnrollNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll);

} // namespace meta_schedule
} // namespace tvm
114 changes: 114 additions & 0 deletions tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from typing import List

from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.mutator import MutateUnroll, Mutator
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule

# pylint: disable=invalid-name, no-member


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [512, 512])
B = T.match_buffer(b, [512, 512])
C = T.match_buffer(c, [512, 512])
for i, j, k in T.grid(512, 512, 512): # type: ignore
with T.block("C"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore
with T.init():
C[vi, vj] = 0.0 # type: ignore
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


# pylint: enable=invalid-name, no-member


def _sch(decisions: List[List[int]]) -> Schedule:
sch = Schedule(matmul, debug_mask="all")
# pylint: disable=invalid-name
d0, d1, d2 = decisions
b0 = sch.get_block(name="C", func_name="main")
root = sch.get_block(name="root", func_name="main")
sch.get_consumers(block=b0)
b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
l2, l3, l4 = sch.get_loops(block=b0)
v5, v6, v7, v8 = sch.sample_perfect_tile(
loop=l2,
n=4,
max_innermost_factor=64,
decision=d0,
)
l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
v13, v14, v15, v16 = sch.sample_perfect_tile(
loop=l3,
n=4,
max_innermost_factor=64,
decision=d1,
)
l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])
v21, v22 = sch.sample_perfect_tile(
loop=l4,
n=2,
max_innermost_factor=64,
decision=d2,
)
l23, l24 = sch.split(loop=l4, factors=[v21, v22])
sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
v57 = sch.sample_categorical(
candidates=[0, 16, 64, 512],
probs=[0.25, 0.25, 0.25, 0.25],
decision=0,
)
sch.annotate(block_or_loop=root, ann_key="meta_schedule.unroll_explicit", ann_val=v57)
# pylint: enable=invalid-name
return sch


def _make_mutator(target: Target) -> Mutator:
mutator = MutateUnroll()
mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target))
return mutator


def test_mutate_unroll_matmul():
mutator = _make_mutator(target=Target("llvm --num-cores=16"))
sch = _sch(
decisions=[
[4, 32, 4, 1],
[8, 4, 8, 2],
[512, 1],
],
)
results = set()
for _ in range(100):
trace = mutator.apply(sch.trace)
decision = trace.decisions[trace.insts[-2]]
results.add(decision)
if len(results) == 3:
break
assert len(results) == 3
assert results == {1, 2, 3}


if __name__ == """__main__""":
test_mutate_unroll_matmul()