diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 6f7c09cdcf2d..1c6aa161e473 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -218,6 +218,33 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     int reduce_extent, group_extent;
     PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
     PrimExpr group_index = FlattenThread(vpar, &group_extent);
+
+    // the longest contiguous reduce extent after flattening
+    int contiguous_reduce_extent = 1;
+    std::vector<std::tuple<int, int, bool>> block_threads;  // tuple(dim_index, extent, is_reduce)
+    for (const ThreadEntry& thr : vred) {
+      if (thr.scope.rank == 1) {  // threadIdx
+        block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
+      }
+    }
+    for (const ThreadEntry& thr : vpar) {
+      if (thr.scope.rank == 1) {  // threadIdx
+        block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
+      }
+    }
+    // sort according to dim_index
+    std::sort(block_threads.begin(), block_threads.end());
+    for (auto&& thr_attr : block_threads) {
+      int dim_index, extent;
+      bool is_reduce;
+      std::tie(dim_index, extent, is_reduce) = thr_attr;
+      if (is_reduce) {
+        contiguous_reduce_extent *= extent;
+      } else {
+        break;
+      }
+    }
+
     std::vector<Stmt> seq;
     std::vector<Var> shared_bufs(size);
     std::vector<Stmt> local_vars;
@@ -238,9 +265,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //
-    if (is_warp_reduction(types)) {
-      // TODO(tvm-team) sub-warp reduction support.
-      ICHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction";
+    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
+      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
       //
       // This is the index to the reduction variable, one reduction
       // variable per warp. Local scope seems easier to reason without
@@ -269,6 +295,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       {
         PrimExpr pred = const_true(1);
         PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+        if (group_extent > 1) {
+          mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
+        }
         seq.emplace_back(Store(mask_var, mask, index, pred));
         // Push allocation with an empty body. Later this will be fixed
         // when the entire body is ready.
@@ -277,7 +306,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }
 
       // Emit reductions within a warp.
-      for (int offset = warp_size_ / 2; offset > 0; offset /= 2) {
+      int start_offset = 1;
+      while (start_offset * 2 < reduce_extent) {
+        start_offset *= 2;
+      }
+      for (int offset = start_offset; offset > 0; offset /= 2) {
         // Load reduction values, no synchronization needed.
         Array<PrimExpr> a, b;
         for (size_t i = 0; i < size; ++i) {
@@ -323,13 +356,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
       // Broadcast the reduction result from lane 0 to all other lanes.
       // This avoids to emit predicated stores, as all threads are
-      // uniformmly writting the same result.
+      // uniformly writing the same result.
       //
       for (size_t i = 0; i < size; ++i) {
         Var var = shared_bufs[i];
         PrimExpr pred = const_true(types[i].lanes());
         PrimExpr val = Load(types[i], var, index, pred);
-        PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0);
+        PrimExpr splat =
+            WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, reduce_extent * group_index);
         seq.push_back(Store(var, splat, index, pred));
       }
 
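For review intuition, the additions above combine as follows: when several groups share a warp, the membership mask is narrowed to the `reduce_extent` contiguous lanes of the current group, the shuffle-down offsets start at the largest power of two strictly below `reduce_extent` instead of at `warp_size_ / 2`, and the final broadcast reads the group's first lane (`reduce_extent * group_index`) instead of lane 0. A standalone sketch of that arithmetic (illustrative only, not part of the patch):

```cpp
// Illustrative sketch of the sub-warp mask and shuffle-offset schedule that
// the lowering above emits. Assumes reduce_extent < 32 (this branch is only
// taken when group_extent > 1, i.e. the group does not fill the warp).
#include <cstdio>

int main() {
  const int reduce_extent = 8;  // lanes cooperating in one reduction
  const int group_index = 2;    // which sub-warp group this thread belongs to

  // Mirrors: mask = tvm_warp_activemask()
  //                 & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
  unsigned active = 0xFFFFFFFFu;  // stand-in for the real active mask
  unsigned mask = active & (((1u << reduce_extent) - 1u) << (reduce_extent * group_index));
  std::printf("member mask: 0x%08X\n", mask);  // 0x00FF0000: lanes 16..23

  // Mirrors the new loop: offsets start just below reduce_extent.
  int start_offset = 1;
  while (start_offset * 2 < reduce_extent) start_offset *= 2;
  for (int offset = start_offset; offset > 0; offset /= 2) {
    std::printf("shuffle down by %d\n", offset);  // 4, 2, 1
  }
  return 0;
}
```

This schedule also covers non-power-of-two extents: every element is still folded into the group's first lane, which is exactly the lane the later `tvm_warp_shuffle` broadcast reads from.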
@@ -346,7 +380,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         warp_allocs_.insert(node.get());
       }
     } else {
-      int threadx_extent = 1;
       if (reduce_extent == 1) {
         // special case, no reduction is needed.
         std::vector<Stmt> stores(size);
@@ -357,10 +390,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         }
         return SeqStmt::Flatten(stores);
       }
-      // Whether the threadIdx.x is involved in reduction.
-      if (vred[0].scope.dim_index == 0) {
-        threadx_extent = vred[0].extent;
-      }
       // This sync is necessary because there might be incomplete read of
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
@@ -372,7 +401,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }
       seq.emplace_back(SyncThread("shared"));
       seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
-                                        reduce_extent, threadx_extent));
+                                        reduce_extent, group_extent, contiguous_reduce_extent));
       for (size_t idx = 0; idx < size; ++idx) {
         ICHECK(!load_remap_.count(buffers[idx]));
         PrimExpr pred = const_true(types[idx].lanes());
@@ -402,7 +431,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
                         const Array<Var>& shared_bufs, PrimExpr reduce_index, PrimExpr group_index,
-                        int reduce_extent, int threadx_extent) {
+                        int reduce_extent, int group_extent, int contiguous_reduce_extent) {
     // Get next power of two
     int reduce_align = 1;
     while (reduce_extent > reduce_align) {
@@ -444,9 +473,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
-    ICHECK(threadx_extent >= 1 && warp_size_ >= 1);
+
     // normal synchronization
-    while (reduce_align > threadx_extent || reduce_align > warp_size_) {
+    bool warp_align = group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
+    while (reduce_align > contiguous_reduce_extent || reduce_align > warp_size_ || !warp_align) {
+      if (reduce_align == 1) {
+        break;
+      }
       reduce_align = reduce_align >> 1;
       PrimExpr cond = reduce_index < reduce_align;
       seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
@@ -534,22 +567,21 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
 
   // Emit warp shuffle calls.
-  PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) {
+  PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, PrimExpr delta_or_lane) {
     PrimExpr pred = const_true(1);
     PrimExpr index(0);
     PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
-    Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width};
+    Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
     return Call(val.dtype(), op, args);
   }
 
-  // Check if this is a reduction on threadIdx.x and its extent matches
-  // the warp size.
+  // Check if we can use warp level reduction.
   //
-  // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
   // Note: The ROCm backend will only have warp reductions for now.
   // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
-  bool is_warp_reduction(const std::vector<DataType>& types) const {
+  bool is_warp_reduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
+                         int contiguous_reduce_extent) const {
     // Only cuda target supports warp reductions.
     if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
 
@@ -575,18 +607,25 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return false;
     }
 
-    const AttrStmtNode* op = thread_extents_.back();
-    DCHECK_EQ(op->attr_key, attr::thread_extent);
-
-    IterVar iv = Downcast<IterVar>(op->node);
-    ThreadEntry e;
-    e.scope = runtime::ThreadScope::Create(iv->thread_tag);
-    e.extent = 0;
-    if (auto ptr = op->value.as<IntImmNode>()) {
-      e.extent = static_cast<int>(ptr->value);
+    // reduce region must be contiguous.
+    if (contiguous_reduce_extent != reduce_extent) {
+      return false;
     }
-    return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank == 1;
+
+    // whether reduce_extent and group_extent are valid for warp reduction.
+    if (target_->kind->name == "rocm") {
+      return reduce_extent == warp_size_;
+    } else {  // target_->kind->name == "cuda"
+      if (reduce_extent == 1) {
+        return false;  // no need to warp reduce
+      } else {
+        if (warp_size_ % reduce_extent == 0) {
+          return true;  // warp size is multiple of reduce extent
+        } else {
+          return group_extent == 1 && reduce_extent <= warp_size_;
+        }
+      }
+    }
   }
 
   // The target.
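The rewritten predicate is easier to review as a decision table over the flattened extents. A minimal host-side sketch (illustrative; the helper name is made up, and the unchanged dtype/target checks that precede these tests in the real function are omitted):

```cpp
// Illustrative sketch (not part of the patch) of the extent checks in the
// rewritten is_warp_reduction.
#include <cstdio>

bool ExtentsAllowWarpReduction(bool is_rocm, int warp_size, int group_extent,
                               int reduce_extent, int contiguous_reduce_extent) {
  // The flattened reduce region must be contiguous.
  if (contiguous_reduce_extent != reduce_extent) return false;
  // ROCm only supports full-warp (wavefront) reductions.
  if (is_rocm) return reduce_extent == warp_size;
  if (reduce_extent == 1) return false;             // nothing to reduce
  if (warp_size % reduce_extent == 0) return true;  // groups tile a warp evenly
  // Otherwise only a single group that fits in one warp can use shuffles.
  return group_extent == 1 && reduce_extent <= warp_size;
}

int main() {
  std::printf("%d\n", ExtentsAllowWarpReduction(false, 32, 4, 8, 8));    // 1: 8 divides 32
  std::printf("%d\n", ExtentsAllowWarpReduction(false, 32, 1, 20, 20));  // 1: one group, fits a warp
  std::printf("%d\n", ExtentsAllowWarpReduction(false, 32, 2, 20, 20));  // 0: 20 doesn't divide 32
  return 0;
}
```

The `warp_size_ % reduce_extent == 0` arm is what enables multiple groups per warp; when the extents don't tile evenly, shuffles remain usable only for a single group that fits in one warp, and everything else falls back to the shared-memory path.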
diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py b/tests/python/unittest/test_subwarp_reduction_cuda.py
new file mode 100644
index 000000000000..8778c75f5699
--- /dev/null
+++ b/tests/python/unittest/test_subwarp_reduction_cuda.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+import numpy as np
+from tvm.script import tir as T
+
+
+@T.prim_func
+def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
+    A = T.match_buffer(a, [1, d1, d2, d3])
+    B = T.match_buffer(b, [1, d1, d2])
+
+    for i, j, k, l in T.grid(1, d1, d2, d3):
+        with T.block("reduce"):
+            vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
+            with T.init():
+                B[vi, vj, vk] = 0.0
+            B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_cuda_subwarp_reduction():
+    def check(d1: int, d2: int, d3: int):
+        _, _, _d1, _d2, _d3 = reduce.params
+        mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
+        sch = tvm.tir.Schedule(mod)
+        blk = sch.get_block("reduce")
+        i, j, k, l = sch.get_loops(blk)
+        sch.bind(i, "blockIdx.x")
+        sch.bind(j, "threadIdx.z")
+        sch.bind(k, "threadIdx.y")
+        sch.bind(l, "threadIdx.x")
+        f = tvm.build(sch.mod["main"], target="cuda")
+
+        # prepare input and output array
+        a_np = np.random.rand(1, d1, d2, d3).astype("float32")
+        b_np = a_np.sum(axis=-1).astype("float32")
+        a = tvm.nd.array(a_np, tvm.cuda(0))
+        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
+
+        # launch kernel
+        f(a, b)
+        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
+    for d1 in range(1, 5):
+        for d2 in range(1, 5):
+            for d3 in range(2, 33):
+                check(d1, d2, d3)
+
+
+if __name__ == "__main__":
+    test_cuda_subwarp_reduction()
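As a cross-check on why this schedule reaches the sub-warp path: the only reduce axis `l` is bound to `threadIdx.x` (dim_index 0), so the dim-index sort in the C++ pass sees the reduce dimension first and yields `contiguous_reduce_extent == d3`. A small sketch of that scan for the test's bindings (illustrative re-implementation with a hypothetical helper name, not the pass itself):

```cpp
// Illustrative re-implementation of the contiguous_reduce_extent scan for the
// thread bindings used in the test above:
//   threadIdx.x -> reduce axis (extent d3), threadIdx.y/z -> group axes.
#include <algorithm>
#include <cstdio>
#include <tuple>
#include <vector>

int ContiguousReduceExtent(std::vector<std::tuple<int, int, bool>> threads) {
  // tuple(dim_index, extent, is_reduce); sorting orders x before y before z.
  std::sort(threads.begin(), threads.end());
  int extent = 1;
  for (const auto& [dim_index, ext, is_reduce] : threads) {
    (void)dim_index;
    if (!is_reduce) break;  // stop at the first group (non-reduce) dimension
    extent *= ext;
  }
  return extent;
}

int main() {
  // Bindings from the test: x=d3 (reduce), y=d2 (group), z=d1 (group).
  int d1 = 3, d2 = 4, d3 = 17;
  std::printf("%d\n", ContiguousReduceExtent({{2, d1, false}, {1, d2, false}, {0, d3, true}}));
  // Prints 17: the reduce region is contiguous, so the warp path is taken
  // whenever is_warp_reduction's extent checks also pass.
  return 0;
}
```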