python/tvm/dlight/gpu/reduction.py (6 additions, 1 deletion)
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""A rule for reduction. """
+"""A rule for reduction."""
 # TODO: combine reduction rule and general reduction rule into one file.
 from typing import List, Mapping, Optional, Tuple, Union
 
@@ -47,6 +47,10 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
     return buffer_store.value.b
 
 
+def _has_reduction_loop(block_info):
+    return any([info.kind == "R" for info in block_info.iters])
+
+
 class Reduction(GPUScheduleRule):
     """A rule for Reduction."""
 
@@ -79,6 +83,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
         # Step 1. Check reduction block
         if (
             (not block_info.is_reduction())
+            or (not _has_reduction_loop(block_info))
             or len(block_stmt.writes) != 1
             or _get_reduction_expr(block_stmt) is None
         ):
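The guard added above is the whole point of the change, so a minimal sketch may help. The snippet below is plain Python with a hypothetical FakeIter stand-in, not dlight's real IterInfo/BlockInfo classes: after normalization folds an extent-1 reduction loop, a block can still report is_reduction() (it keeps its init branch) while no "R" iterator survives for the rule to schedule, which is exactly the case the new check rejects.

# A minimal sketch of the new guard, assuming only that dlight iterator
# kinds are "S" (spatial) and "R" (reduction); FakeIter is a hypothetical
# stand-in for the real IterInfo.
from dataclasses import dataclass
from typing import List


@dataclass
class FakeIter:
    kind: str  # "S" = spatial, "R" = reduction


def has_reduction_loop(iters: List[FakeIter]) -> bool:
    # Mirrors _has_reduction_loop: is any reduction iterator left?
    return any(info.kind == "R" for info in iters)


# Before normalization: the SSSR matmul has a reduction iterator.
assert has_reduction_loop([FakeIter("S"), FakeIter("S"), FakeIter("S"), FakeIter("R")])
# After the extent-1 k loop is folded away, only spatial iterators remain,
# so the Reduction rule must bail out rather than schedule a phantom loop.
assert not has_reduction_loop([FakeIter("S"), FakeIter("S"), FakeIter("S")])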
tests/python/dlight/test_gpu_reduction.py (26 additions, 0 deletions)
@@ -1152,5 +1152,31 @@ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((
     assert_structural_equal(mod, Expected)
 
 
+def test_no_reduction_loop_check():
+    # The normalized prim func will not contain a reduction loop since its extent is one.
+    # This checks that the Reduction schedule is correctly not applied in this case.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")):
+            T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                    T.reads(lv43[v_i0, v_i1, v_k], lv44[v_i0, v_k, v_i2])
+                    T.writes(matmul[v_i0, v_i1, v_i2])
+                    with T.init():
+                        matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
+                    matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_i0, v_k, v_i2]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
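To see the behavior the test relies on, the hedged sketch below reuses the prim func from the new test and prints which iterator kinds survive normalization. It assumes normalize_prim_func, BlockInfo.name, BlockInfo.iters, IterInfo.kind, and BlockInfo.is_reduction() are exposed from tvm.dlight.base as they are on current TVM main; treat it as an illustration under those assumptions, not a pinned API.

from tvm import tir
from tvm.dlight.base import normalize_prim_func  # assumed export, as on TVM main
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Before:
    @T.prim_func(private=True)
    def matmul(
        lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"),
        lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"),
        matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"),
    ):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                with T.init():
                    matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
                matmul[v_i0, v_i1, v_i2] = (
                    matmul[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_i0, v_k, v_i2]
                )


# Schedule the lone prim func, as the dlight rules do internally.
sch = tir.Schedule(Before["matmul"])
for info in normalize_prim_func(sch):
    # Expected: the "matmul" block still reports is_reduction() because it
    # keeps its T.init() branch, yet no "R" iterator survives the folding of
    # the extent-1 k loop. That combination is what the new guard rejects.
    print(info.name, [i.kind for i in info.iters], info.is_reduction())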