Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,10 +1079,133 @@ def dil_2(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
)


def test_cpu_gmm():
# fmt: off
@T.prim_func
def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1 in T.grid(1, 4, 2, 1, 1, 8):
for i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(128, 1, 16, 1, 1, 1, 2, 8):
with T.block("Z"):
b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
k = T.axis.reduce(128, i3_1 + i3_0)
T.reads(X[b, i, k], Y[b, k, j])
T.writes(Z_global[b, i, j])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
Z_global[b, i, j] = T.float32(0)
Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
for ax0, ax1, ax2 in T.grid(1, 32, 8):
with T.block("Z_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
v2 = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + ax2)
T.reads(Z_global[v0, v1, v2])
T.writes(Z[v0, v1, v2])
Z[v0, v1, v2] = Z_global[v0, v1, v2]
@T.prim_func
def gmm_1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
for i0_0, i1_0, i2_0 in T.grid(1, 4, 2):
for i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
with T.block("Z"):
b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
k = T.axis.reduce(128, i3_1 + i3_0)
T.reads(X[b, i, k], Y[b, k, j])
T.writes(Z_global[b, i, j])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
Z_global[b, i, j] = T.float32(0)
Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
for ax0, ax1, ax2 in T.grid(1, 32, 64):
with T.block("Z_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
v2 = T.axis.spatial(128, i2_0 * 64 + ax2)
T.reads(Z_global[v0, v1, v2])
T.writes(Z[v0, v1, v2])
Z[v0, v1, v2] = Z_global[v0, v1, v2]
@T.prim_func
def gmm_2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 4, 2, 1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
with T.block("Z"):
b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
k = T.axis.reduce(128, i3_1 + i3_0)
T.reads(X[b, i, k], Y[b, k, j])
T.writes(Z[b, i, j])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
Z[b, i, j] = T.float32(0)
Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [4, 1, 16, 2]),
("SamplePerfectTile", [2, 8, 1, 8]),
("SamplePerfectTile", [128, 1]),
("SampleCategorical", 1),
]
decision_1 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [4, 1, 16, 2]),
("SamplePerfectTile", [2, 8, 1, 8]),
("SamplePerfectTile", [128, 1]),
("SampleCategorical", 1),
]
decision_2 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [4, 1, 16, 2]),
("SamplePerfectTile", [2, 8, 1, 8]),
("SamplePerfectTile", [128, 1]),
("SampleCategorical", 1),
]
mod = create_te_workload("GMM", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[gmm_0, gmm_1, gmm_2],
expected_decisions=[decision_0, decision_1, decision_2],
)


if __name__ == "__main__":
test_cpu_c1d()
test_cpu_c2d()
test_cpu_c3d()
test_cpu_cap()
test_cpu_dep()
test_cpu_dil()
test_cpu_gmm()
82 changes: 82 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,10 +572,92 @@ def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
)


def test_cuda_gmm():
# fmt: off
@T.prim_func
def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.unroll_explicit":1024})
Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_fused in T.thread_binding(1, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_fused in T.thread_binding(32, thread="vthread.x"):
for i0_2_i1_2_i2_2_fused in T.thread_binding(2, thread="threadIdx.x"):
for i3_0 in T.serial(1):
for ax0_ax1_ax2_fused in T.serial(16384):
with T.block("X_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
T.reads(X[v0, v1, v2])
T.writes(X_shared[v0, v1, v2])
T.block_attr({"meta_schedule.cooperative_fetch":2})
X_shared[v0, v1, v2] = X[v0, v1, v2]
for ax0_ax1_ax2_fused in T.serial(16384):
with T.block("Y_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
T.reads(Y[v0, v1, v2])
T.writes(Y_shared[v0, v1, v2])
T.block_attr({"meta_schedule.cooperative_fetch":1})
Y_shared[v0, v1, v2] = Y[v0, v1, v2]
for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(32, 1, 2, 64, 4, 1, 2, 1):
with T.block("Z"):
b = T.axis.spatial(1, i0_4 + i0_3)
i = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + i1_3 * 2 + i1_4)
j = T.axis.spatial(128, i2_4 + i0_2_i1_2_i2_2_fused * 64 + i2_3)
k = T.axis.reduce(128, i3_0 * 128 + i3_1 * 4 + i3_2)
T.reads(X_shared[b, i, k], Y_shared[b, k, j])
T.writes(Z_local[b, i, j])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
with T.init():
Z_local[b, i, j] = T.float32(0)
Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
for ax0, ax1, ax2 in T.grid(1, 4, 64):
with T.block("Z_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + ax1)
v2 = T.axis.spatial(128, i0_2_i1_2_i2_2_fused * 64 + ax2)
T.reads(Z_local[v0, v1, v2])
T.writes(Z[v0, v1, v2])
Z[v0, v1, v2] = Z_local[v0, v1, v2]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [1, 32, 1, 2, 2]),
("SamplePerfectTile", [1, 1, 2, 64, 1]),
("SamplePerfectTile", [1, 32, 4]),
("SampleCategorical", 1),
("SampleCategorical", 0),
("SampleCategorical", 4),
]
mod = create_te_workload("GMM", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[gmm_0],
expected_decisions=[decision_0],
)


if __name__ == "__main__":
test_cuda_c1d()
test_cuda_c2d()
test_cuda_c3d()
test_cuda_cap()
test_cuda_dep()
test_cuda_dil()
test_cuda_gmm()