Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));

if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
Buffer buf = Downcast<Buffer>(it->second);
auto write_ptr = node.CopyOnWrite();
write_ptr->buffer_var = repl->buffer_var;
write_ptr->dtype = repl->dtype;
write_ptr->extents = repl->extents;
write_ptr->condition = repl->condition;
write_ptr->buffer_var = buf->data;
write_ptr->dtype = buf->dtype;
write_ptr->extents = buf->shape;
write_ptr->condition = const_true(buf->dtype.lanes());

if (buf.scope() == "shared") {
// Use volatile access to shared buffer.
write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body);
}
}
return std::move(node);
}
Expand Down Expand Up @@ -344,15 +349,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// 4. Load staging buffer.
// Second round of allreduce.
for (size_t i = 0; i < size; ++i) {
values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index});
values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
/*indices=*/{group_index * n_warps + reduce_index});
}
if (n_warps < warp_size_) {
mask = mask & (((1 << n_warps) - 1) << group_index);
mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
}
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, n_warps, group_index, mask,
/*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
&seq);
/*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());

// 5. Create shared memory buffer(s) of `group_extent` elements, storing
Expand All @@ -365,9 +370,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
write_result.push_back(
BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
BufferStore(broadcast_shared_buf, reduce_results[i], {group_index}));
// Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread("shared"));
Expand All @@ -382,7 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
load_remap_[buffers[i]->data.get()] = reduce_results[i];

auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = node;
alloc_remap_[buffers[i]->data.get()] = buf;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
}
Expand All @@ -400,7 +405,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared");
shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)},
types[idx], "red_buf" + std::to_string(idx), "shared");
seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index, reduce_extent)}));
}
Expand All @@ -414,9 +420,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
{BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
alloc_remap_[buffers[idx]->data.get()] =
Allocate(shared_bufs[idx]->data, types[idx],
{PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
Expand Down Expand Up @@ -772,7 +776,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// The load remap
std::unordered_map<const VarNode*, PrimExpr> load_remap_;
// Allocate remap
std::unordered_map<const VarNode*, Stmt> alloc_remap_;
std::unordered_map<const VarNode*, Buffer> alloc_remap_;
// BufferVar remap
std::unordered_map<const VarNode*, Var> var_remap_;
// Buffer remap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32)

@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_subwarp_reduction():
def test_allreduce_cuda():
def check_sum(d1: int, d2: int, d3: int):
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
Expand Down Expand Up @@ -95,10 +95,12 @@ def check_max(d1: int, d2: int, d3: int):

for d1 in range(1, 5):
for d2 in range(1, 5):
for d3 in range(2, 33):
for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
if d1 * d2 * d3 > 1024:
continue
check_sum(d1, d2, d3)
check_max(d1, d2, d3)


if __name__ == "__main__":
test_cuda_subwarp_reduction()
test_allreduce_cuda()
24 changes: 14 additions & 10 deletions tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
for i in range(128):
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result = T.allocate([1], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
Expand Down Expand Up @@ -463,6 +464,7 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
red_result = T.allocate([1], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
Expand Down Expand Up @@ -550,6 +552,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_y = T.launch_thread("threadIdx.y", 4)
red_result = T.allocate([4], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
with T.attr(
Expand Down Expand Up @@ -585,23 +588,23 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 16:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
if threadIdx_x < 4:
red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y))
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((4,), data=B.data)
B_1[threadIdx_y] = red_result_1[0]
B_1[threadIdx_y] = red_result_1[threadIdx_y]


class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
Expand Down Expand Up @@ -636,6 +639,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
threadIdx_y = T.launch_thread("threadIdx.y", 2)
in_thread_B = T.allocate([1], "float32", "local")
red_result = T.allocate([2], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 512)
in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
in_thread_B_1[0] = T.float32(0)
Expand Down Expand Up @@ -675,11 +679,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 32:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
if threadIdx_x < 16:
red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 + threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y))
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
Expand All @@ -691,11 +695,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((2,), data=B.data)
B_1[threadIdx_y] = red_result_1[0]
B_1[threadIdx_y] = red_result_1[threadIdx_y]


if __name__ == "__main__":
Expand Down