diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index 9851bb9800fa..4faaa1cab94a 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""A rule for reduction. """
+"""A rule for reduction."""
 # TODO: combine reduction rule and general reduction rule into one file.
 from typing import List, Mapping, Optional, Tuple, Union

@@ -47,6 +47,10 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
     return buffer_store.value.b


+def _has_reduction_loop(block_info) -> bool:
+    return any(info.kind == "R" for info in block_info.iters)
+
+
 class Reduction(GPUScheduleRule):
     """A rule for Reduction."""

@@ -79,6 +83,7 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
         # Step 1. Check reduction block
         if (
             (not block_info.is_reduction())
+            or (not _has_reduction_loop(block_info))
             or len(block_stmt.writes) != 1
             or _get_reduction_expr(block_stmt) is None
         ):
diff --git a/tests/python/dlight/test_gpu_reduction.py b/tests/python/dlight/test_gpu_reduction.py
index 1ce57eb53d22..0a74df70c084 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -1152,5 +1152,31 @@ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((
     assert_structural_equal(mod, Expected)


+def test_no_reduction_loop_check():
+    # After normalization, the prim func will not contain a reduction loop, since the
+    # reduction extent is one. This verifies that the Reduction schedule is not applied
+    # in this case.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")):
+            T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                    T.reads(lv43[v_i0, v_i1, v_k], lv44[v_i0, v_k, v_i2])
+                    T.writes(matmul[v_i0, v_i1, v_i2])
+                    with T.init():
+                        matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
+                    matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_i0, v_k, v_i2]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
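
A minimal self-contained sketch of what the new guard checks. The IterInfo and BlockInfo dataclasses below are illustrative stand-ins, not the real tvm.dlight.base types; only the kind attribute mirrors the interface the helper relies on.

from dataclasses import dataclass
from typing import List


@dataclass
class IterInfo:
    # Stand-in for a loop iterator: kind is "S" (spatial) or "R" (reduction).
    kind: str
    extent: int


@dataclass
class BlockInfo:
    # Stand-in for a normalized block's iterator list.
    iters: List[IterInfo]


def has_reduction_loop(block_info: BlockInfo) -> bool:
    # Same logic as the patch's _has_reduction_loop helper.
    return any(info.kind == "R" for info in block_info.iters)


# The test's matmul has a reduction extent of 1, so normalization drops the
# unit-extent reduction loop and only spatial iterators remain: the guard
# returns False and the Reduction rule leaves the module untouched.
degenerate = BlockInfo(iters=[IterInfo("S", 1), IterInfo("S", 32), IterInfo("S", 1)])
assert not has_reduction_loop(degenerate)

# A matmul with a real reduction axis keeps its "R" iterator and passes the guard.
regular = BlockInfo(iters=[IterInfo("S", 32), IterInfo("R", 128)])
assert has_reduction_loop(regular)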