Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc RewriteReductionBlock();
/*!
* \brief Create a postprocessor that adds thread binding to unbound blocks
* \param max_threadblock The max number of threadblocks in the cuda device.
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteUnboundBlock();
TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblock);
/*!
* \brief Create a postprocessor that tensorize Tensor Core related components
* \return The postprocessor created.
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/meta_schedule/postproc/rewrite_unbound_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
class RewriteUnboundBlock(Postproc):
    """A postprocessor that adds thread binding to unbound blocks."""

    def __init__(self, max_threadblock: int = 256) -> None:
        """Construct the postprocessor.

        Parameters
        ----------
        max_threadblock : int
            The maximum number of threadblocks allowed on the CUDA device;
            forwarded to the C++ implementation via FFI. Defaults to 256.
        """
        self.__init_handle_by_constructor__(
            _ffi_api.PostprocRewriteUnboundBlock,  # type: ignore # pylint: disable=no-member
            max_threadblock,
        )
55 changes: 42 additions & 13 deletions src/meta_schedule/postproc/rewrite_unbound_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,25 @@ class RewriteUnboundBlockNode : public PostprocNode {
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {
  CHECK(context->target.defined()) << "ValueError: target is not defined";
  // Cache `max_threads_per_block` from the target; Apply() uses it to size
  // the threadIdx.x extent when splitting unbound loops.
  Optional<Integer> max_num_threads =
      context->target.value()->GetAttr<Integer>("max_threads_per_block");
  CHECK(max_num_threads.defined())
      << "ValueError: missing attribute `max_threads_per_block` in the target";
  this->max_num_threads_ = max_num_threads.value();
}

// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;

public:
/*! \brief The cached warp size from Target */
int warp_size_ = -1;
/*! \brief The max number of threads per block from Target */
int max_num_threads_ = -1;
/*! \brief The max number of threadblocks in the cuda device */
int max_threadblock_ = -1;

void VisitAttrs(tvm::AttrVisitor* v) {
// `warp_size_` is not visited
// `max_num_threads_` is not visited
// `max_threadblock_` is not visited
}

static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock";
Expand All @@ -178,7 +183,7 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) {
using tir::BlockRV;
using tir::LoopRV;
using tir::Schedule;
ICHECK_NE(this->warp_size_, -1);
ICHECK_NE(this->max_num_threads_, -1);
std::vector<std::pair<tir::StmtSRef, String>> unbound_blocks =
tir::UnboundBlockFinder::Find(sch->state());
for (const auto& kv : unbound_blocks) {
Expand All @@ -195,18 +200,42 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) {
if (bind_type == tir::BindType::kBindBlock) {
sch->Bind(fused, "blockIdx.x");
} else if (bind_type == tir::BindType::kBindBlockThread) {
Array<LoopRV> splits = sch->Split(fused, {NullOpt, Integer(this->warp_size_)});
ICHECK_EQ(splits.size(), 2);
sch->Bind(splits[0], "blockIdx.x");
sch->Bind(splits[1], "threadIdx.x");
int64_t extent_size = 0;
Array<LoopRV> splits;
if (const int64_t* extent_ptr = tir::GetLoopIntExtent(sch->Get(fused).get())) {
extent_size = *extent_ptr;
if (extent_size > max_threadblock_ * max_num_threads_) {
splits =
sch->Split(fused, {NullOpt, Integer(max_threadblock_), Integer(max_num_threads_)});
ICHECK_EQ(splits.size(), 3);
sch->Reorder({splits[1], splits[2], splits[0]});
sch->Bind(splits[1], "blockIdx.x");
sch->Bind(splits[2], "threadIdx.x");
} else {
ICHECK_NE(extent_size, 0);
splits = sch->Split(
fused,
{NullOpt, Integer(std::min(static_cast<int64_t>(max_num_threads_), extent_size))});
ICHECK_EQ(splits.size(), 2);
sch->Bind(splits[0], "blockIdx.x");
sch->Bind(splits[1], "threadIdx.x");
}
} else {
// loop is dynamic, returns nullptr
splits = sch->Split(fused, {NullOpt, Integer(max_num_threads_)});
ICHECK_EQ(splits.size(), 2);
sch->Bind(splits[0], "blockIdx.x");
sch->Bind(splits[1], "threadIdx.x");
}
}
}
return true;
}

/*!
 * \brief Create the RewriteUnboundBlock postprocessor.
 * \param max_threadblock The max number of threadblocks in the cuda device.
 * \note `max_num_threads_` is deliberately left as -1 here; it is filled in
 *       from the target's `max_threads_per_block` attribute when
 *       InitializeWithTuneContext runs.
 */
Postproc Postproc::RewriteUnboundBlock(int max_threadblock) {
  ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>();
  n->max_threadblock_ = max_threadblock;
  n->max_num_threads_ = -1;
  return Postproc(n);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from tvm.meta_schedule.postproc import RewriteUnboundBlock
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir.schedule.schedule import Schedule


def _target() -> Target:
    # Explicit max_threads_per_block so the postprocessor splits by 1024
    # threads rather than the generic cuda default.
    return Target("cuda --max_threads_per_block=1024", host="llvm")


def _create_context(mod, target) -> TuneContext:
Expand Down Expand Up @@ -63,11 +64,11 @@ class After_cooperative_fetch:
def main(var_A: T.handle, var_B: T.handle) -> None:
A = T.match_buffer(var_A, [512, 512], dtype="float32")
B = T.match_buffer(var_B, [512, 512], dtype="float32")
for i_j_fused_0 in T.thread_binding(0, 8192, thread="blockIdx.x"):
for i_j_fused_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
for i_j_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
with T.block("C"):
vi = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) // 512)
vj = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) % 512)
vi = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) // 512)
vj = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) % 512)
B[vi, vj] = A[vi, vj] + 1.0


Expand All @@ -94,23 +95,180 @@ class After_norm_bmn:
def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
C = T.alloc_buffer([1], dtype="float32")
for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
for i1, i2 in T.grid(256, 256):
with T.block("C"):
b = T.axis.S(1, 0)
i, j = T.axis.remap("RR", [i1, i2])
T.where(i0_fused_1 < 1)
with T.init():
C[b] = T.float32(0)
C[b] = C[b] + A[b, i, j] * A[b, i, j]
for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
with T.block("D"):
b = T.axis.S(1, 0)
T.where(i0_fused_1 < 1)
D[b] = T.sqrt(C[b], dtype="float32")


# Input workload: fused reshape/transpose/reshape taken from a BERT model.
# The fused extent is 1536 * 32 = 49152, small enough (per the "after" module
# below, which binds 48 blocks x 1024 threads) to take the plain two-way
# split path of RewriteUnboundBlock.
@tvm.script.ir_module
class Bert_fused_reshape_transpose_reshape:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(12, 64, 64), "float32"], T_reshape: T.Buffer[(64, 768), "float32"]
    ) -> None:
        for i0_i1_fused_0, i0_i1_fused_1 in T.grid(1536, 32):
            with T.block("T_reshape_1"):
                ax0 = T.axis.spatial(64, (i0_i1_fused_0 * 32 + i0_i1_fused_1) // 768)
                ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 32 + i0_i1_fused_1) % 768)
                T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64])
                T.writes(T_reshape[ax0, ax1])
                T_reshape[ax0, ax1] = placeholder[
                    ((ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64) % 12,
                    (ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) % 64,
                    ax1 % 64 % 64,
                ]


# Same workload as above but with an oversized fused extent
# (1536000 * 32 = 49152000), intended to exceed the blockIdx * threadIdx
# budget and exercise the three-way split + serial-outer-loop path
# (see the "_after_rub_large" expected module below).
@tvm.script.ir_module
class Bert_fused_reshape_transpose_reshape_large:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(12, 64, 64), "float32"], T_reshape: T.Buffer[(64, 768), "float32"]
    ) -> None:
        for i0_i1_fused_0, i0_i1_fused_1 in T.grid(1536000, 32):
            with T.block("T_reshape_1"):
                ax0 = T.axis.spatial(64, (i0_i1_fused_0 * 32 + i0_i1_fused_1) // 768)
                ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 32 + i0_i1_fused_1) % 768)
                T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64])
                T.writes(T_reshape[ax0, ax1])
                T_reshape[ax0, ax1] = placeholder[
                    ((ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64) % 12,
                    (ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) % 64,
                    ax1 % 64 % 64,
                ]


# Expected output for the small workload: the fused loop is split two ways
# and bound to 48 blockIdx.x x 1024 threadIdx.x (48 * 1024 = 49152, exactly
# the fused extent, so no T.where guard is needed).
@tvm.script.ir_module
class Bert_fused_reshape_transpose_reshape_after_rub:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(12, 64, 64), "float32"], T_reshape: T.Buffer[(64, 768), "float32"]
    ) -> None:
        for i0_i1_fused_0_i0_i1_fused_1_fused_0 in T.thread_binding(48, thread="blockIdx.x"):
            for i0_i1_fused_0_i0_i1_fused_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_reshape_1"):
                    ax0 = T.axis.spatial(
                        64,
                        (
                            (
                                i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024
                                + i0_i1_fused_0_i0_i1_fused_1_fused_1
                            )
                            // 32
                            * 32
                            + (
                                i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024
                                + i0_i1_fused_0_i0_i1_fused_1_fused_1
                            )
                            % 32
                        )
                        // 768,
                    )
                    ax1 = T.axis.spatial(
                        768,
                        (
                            (
                                i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024
                                + i0_i1_fused_0_i0_i1_fused_1_fused_1
                            )
                            // 32
                            * 32
                            + (
                                i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024
                                + i0_i1_fused_0_i0_i1_fused_1_fused_1
                            )
                            % 32
                        )
                        % 768,
                    )
                    T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64])
                    T.writes(T_reshape[ax0, ax1])
                    T_reshape[ax0, ax1] = placeholder[
                        ((ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64) % 12,
                        (ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) % 64,
                        ax1 % 64 % 64,
                    ]


# Expected output for the oversized workload: a three-way split bound to
# 256 blockIdx.x x 1024 threadIdx.x with a serial outer loop of 188 rounds
# (256 * 1024 * 188 = 49283072 >= 49152000), and a T.where guard masking
# the overshoot beyond the original fused extent of 49152000.
@tvm.script.ir_module
class Bert_fused_reshape_transpose_reshape_after_rub_large:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(12, 64, 64), "float32"], T_reshape: T.Buffer[(64, 768), "float32"]
    ) -> None:
        # body
        # with T.block("root")
        for i0_i1_fused_0_i0_i1_fused_1_fused_1 in T.thread_binding(256, thread="blockIdx.x"):
            for i0_i1_fused_0_i0_i1_fused_1_fused_2 in T.thread_binding(1024, thread="threadIdx.x"):
                for i0_i1_fused_0_i0_i1_fused_1_fused_0 in T.serial(188):
                    with T.block("T_reshape_1"):
                        ax0 = T.axis.spatial(
                            64,
                            (
                                (
                                    i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_2
                                )
                                // 32
                                * 32
                                + (
                                    i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_2
                                )
                                % 32
                            )
                            // 768,
                        )
                        ax1 = T.axis.spatial(
                            768,
                            (
                                (
                                    i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_2
                                )
                                // 32
                                * 32
                                + (
                                    i0_i1_fused_0_i0_i1_fused_1_fused_0 * 262144
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_1 * 1024
                                    + i0_i1_fused_0_i0_i1_fused_1_fused_2
                                )
                                % 32
                            )
                            % 768,
                        )
                        T.where(
                            (
                                i0_i1_fused_0_i0_i1_fused_1_fused_0 * 256
                                + i0_i1_fused_0_i0_i1_fused_1_fused_1
                            )
                            * 1024
                            + i0_i1_fused_0_i0_i1_fused_1_fused_2
                            < 49152000
                        )
                        T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64])
                        T.writes(T_reshape[ax0, ax1])
                        T_reshape[ax0, ax1] = placeholder[
                            ((ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64)
                            % 12,
                            (ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) % 64,
                            ax1 % 64 % 64,
                        ]


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

Expand All @@ -135,6 +293,28 @@ def test_rewrite_norm_bmn():
tvm.ir.assert_structural_equal(sch.mod, After_norm_bmn)


def test_rewrite_cuda_loop_split_no_reduction():
    """RewriteUnboundBlock on a fused extent that fits the device budget:
    expect the plain blockIdx.x/threadIdx.x two-way split."""
    mod = Bert_fused_reshape_transpose_reshape
    target = Target("nvidia/nvidia-v100", host="llvm")
    ctx = _create_context(mod, target)
    sch = tir.Schedule(mod, debug_mask="all")
    sch.enter_postproc()
    # The postprocessor must succeed and rewrite the module in place.
    assert ctx.postprocs[0].apply(sch)
    tvm.ir.assert_structural_equal(sch.mod, Bert_fused_reshape_transpose_reshape_after_rub)


def test_rewrite_cuda_loop_split_no_reduction_large():
    """RewriteUnboundBlock on an oversized fused extent: expect the
    three-way split with a serial outer loop and a T.where guard."""
    mod = Bert_fused_reshape_transpose_reshape_large
    target = Target("nvidia/nvidia-v100", host="llvm")
    ctx = _create_context(mod, target)
    sch = tir.Schedule(mod, debug_mask="all")
    sch.enter_postproc()
    # The postprocessor must succeed and rewrite the module in place.
    assert ctx.postprocs[0].apply(sch)
    tvm.ir.assert_structural_equal(sch.mod, Bert_fused_reshape_transpose_reshape_after_rub_large)


# Allow running this test file directly, outside of pytest.
if __name__ == "__main__":
    test_rewrite_cooperative_fetch()
    test_rewrite_norm_bmn()
    test_rewrite_cuda_loop_split_no_reduction()
    test_rewrite_cuda_loop_split_no_reduction_large()