Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 70 additions & 31 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,33 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
int reduce_extent, group_extent;
PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
PrimExpr group_index = FlattenThread(vpar, &group_extent);

// the longest contiguous reduce extent after flattening
int contiguous_reduce_extent = 1;
std::vector<std::tuple<int, int, bool>> block_threads; // tuple(dim_index, extent, is_reduce)
for (const ThreadEntry& thr : vred) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
}
}
for (const ThreadEntry& thr : vpar) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
}
}
// sort according to dim_index
std::sort(block_threads.begin(), block_threads.end());
for (auto&& thr_attr : block_threads) {
int dim_index, extent;
bool is_reduce;
std::tie(dim_index, extent, is_reduce) = thr_attr;
if (is_reduce) {
contiguous_reduce_extent *= extent;
} else {
break;
}
}

std::vector<Stmt> seq;
std::vector<Var> shared_bufs(size);
std::vector<Stmt> local_vars;
Expand All @@ -238,9 +265,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
if (is_warp_reduction(types)) {
// TODO(tvm-team) sub-warp reduction support.
ICHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction";
if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
//
// This is the index to the reduction variable, one reduction
// variable per warp. Local scope seems easier to reason without
Expand Down Expand Up @@ -269,6 +295,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
{
PrimExpr pred = const_true(1);
PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
if (group_extent > 1) {
mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
}
seq.emplace_back(Store(mask_var, mask, index, pred));
// Push allocation with an empty body. Later this will be fixed
// when the entire body is ready.
Expand All @@ -277,7 +306,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}

// Emit reductions within a warp.
for (int offset = warp_size_ / 2; offset > 0; offset /= 2) {
int start_offset = 1;
while (start_offset * 2 < reduce_extent) {
start_offset *= 2;
}
for (int offset = start_offset; offset > 0; offset /= 2) {
// Load reduction values, no synchronization needed.
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
Expand Down Expand Up @@ -323,13 +356,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {

// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids to emit predicated stores, as all threads are
// uniformmly writting the same result.
// uniformly writting the same result.
//
for (size_t i = 0; i < size; ++i) {
Var var = shared_bufs[i];
PrimExpr pred = const_true(types[i].lanes());
PrimExpr val = Load(types[i], var, index, pred);
PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0);
PrimExpr splat =
WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, reduce_extent * group_index);
seq.push_back(Store(var, splat, index, pred));
}

Expand All @@ -346,7 +380,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
warp_allocs_.insert(node.get());
}
} else {
int threadx_extent = 1;
if (reduce_extent == 1) {
// special case, no reduction is needed.
std::vector<Stmt> stores(size);
Expand All @@ -357,10 +390,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}
return SeqStmt::Flatten(stores);
}
// Whether the threadIdx.x is involved in reduction.
if (vred[0].scope.dim_index == 0) {
threadx_extent = vred[0].extent;
}
// This sync is necessary because there might be incomplete read of
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
Expand All @@ -372,7 +401,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
reduce_extent, threadx_extent));
reduce_extent, group_extent, contiguous_reduce_extent));
for (size_t idx = 0; idx < size; ++idx) {
ICHECK(!load_remap_.count(buffers[idx]));
PrimExpr pred = const_true(types[idx].lanes());
Expand Down Expand Up @@ -402,7 +431,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
const Array<Var>& shared_bufs, PrimExpr reduce_index, PrimExpr group_index,
int reduce_extent, int threadx_extent) {
int reduce_extent, int group_extent, int contiguous_reduce_extent) {
// Get next power of two
int reduce_align = 1;
while (reduce_extent > reduce_align) {
Expand Down Expand Up @@ -444,9 +473,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
ICHECK(threadx_extent >= 1 && warp_size_ >= 1);

// normal synchronization
while (reduce_align > threadx_extent || reduce_align > warp_size_) {
bool warp_align = group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
while (reduce_align > contiguous_reduce_extent || reduce_align > warp_size_ || !warp_align) {
if (reduce_align == 1) {
break;
}
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
Expand Down Expand Up @@ -534,22 +567,21 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}

// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) {
PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, PrimExpr delta_or_lane) {
PrimExpr pred = const_true(1);
PrimExpr index(0);
PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width};
Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
return Call(val.dtype(), op, args);
}

// Check if this is a reduction on threadIdx.x and its extent matches
// the warp size.
// Check if we can use warp level reduction.
//
// TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
// Note: The ROCm backend will only have warp reductions for now.
// Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
bool is_warp_reduction(const std::vector<DataType>& types) const {
bool is_warp_reduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
int contiguous_reduce_extent) const {
// Only cuda target supports warp reductions.
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;

Expand All @@ -575,18 +607,25 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
return false;
}

const AttrStmtNode* op = thread_extents_.back();
DCHECK_EQ(op->attr_key, attr::thread_extent);

IterVar iv = Downcast<IterVar>(op->node);
ThreadEntry e;
e.scope = runtime::ThreadScope::Create(iv->thread_tag);
e.extent = 0;
if (auto ptr = op->value.as<IntImmNode>()) {
e.extent = static_cast<int>(ptr->value);
// reduce region must be contiguous.
if (contiguous_reduce_extent != reduce_extent) {
return false;
}

return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank == 1;
// whether reduce_extent and group_extent are vaild for warp reduction.
if (target_->kind->name == "rocm") {
return reduce_extent == warp_size_;
} else { // target_->kind->name == "cuda"
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
if (warp_size_ % reduce_extent == 0) {
return true; // warp size is multiple of reduce extent
} else {
return group_extent == 1 && reduce_extent <= warp_size_;
}
}
}
}

// The target.
Expand Down
68 changes: 68 additions & 0 deletions tests/python/unittest/test_subwarp_reduction_cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
import numpy as np
from tvm.script import tir as T


@T.prim_func
def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
A = T.match_buffer(a, [1, d1, d2, d3])
B = T.match_buffer(b, [1, d1, d2])

for i, j, k, l in T.grid(1, d1, d2, d3):
with T.block("reduce"):
vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
with T.init():
B[vi, vj, vk] = 0.0
B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_subwarp_reduction():
def check(d1: int, d2: int, d3: int):
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
sch = tvm.tir.Schedule(mod)
blk = sch.get_block("reduce")
i, j, k, l = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.z")
sch.bind(k, "threadIdx.y")
sch.bind(l, "threadIdx.x")
f = tvm.build(sch.mod["main"], target="cuda")

# prepare input and output array
a_np = np.random.rand(1, d1, d2, d3).astype("float32")
b_np = a_np.sum(axis=-1).astype("float32")
a = tvm.nd.array(a_np, tvm.cuda(0))
b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))

# launch kernel
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)

for d1 in range(1, 5):
for d2 in range(1, 5):
for d3 in range(2, 33):
check(d1, d2, d3)


if __name__ == "__main__":
test_cuda_subwarp_reduction()