Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,9 +901,188 @@ def dep_2(placeholder: T.Buffer[(1, 112, 112, 32), "float32"], placeholder_1: T.
)


def test_cpu_dil():
    """Check the CPU design space generated for the DIL (dilated conv2d) workload.

    Three sketches are expected, differing only in where the padded input is
    computed (compute-at location 7 / 8 / 1) and whether the output is cached:
    ``dil_0`` and ``dil_1`` write through a ``conv2d_nhwc_global`` cache,
    ``dil_2`` writes the output buffer directly.
    """
    # fmt: off
    @T.prim_func
    def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 3, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 109, 109, 64), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
            PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
            conv2d_nhwc_global = T.alloc_buffer([1, 109, 109, 64], dtype="float32")
            for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(1, 109, 1, 4, 1, 1, 1, 2):
                for ax0, ax1, ax2, ax3 in T.grid(1, 13, 229, 3):
                    with T.block("PadInput"):
                        i0 = T.axis.spatial(1, ax0)
                        i1 = T.axis.spatial(230, i1_0 * 2 + ax1)
                        i2 = T.axis.spatial(230, ax2)
                        i3 = T.axis.spatial(3, ax3)
                        T.reads(inputs[i0, i1 - 3, i2 - 3, i3])
                        T.writes(PadInput[i0, i1, i2, i3])
                        PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
                for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
                    with T.block("conv2d_nhwc"):
                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
                        h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
                        w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
                        co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
                        rh = T.axis.reduce(7, i4_0 + i4_1)
                        rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
                        rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                        T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
                        T.writes(conv2d_nhwc_global[n, h, w, co])
                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                        with T.init():
                            conv2d_nhwc_global[n, h, w, co] = T.float32(0)
                        conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight[rh, rw, rc, co]
                for ax0, ax1, ax2, ax3 in T.grid(1, 1, 109, 8):
                    with T.block("conv2d_nhwc_global"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(109, i1_0 + ax1)
                        v2 = T.axis.spatial(109, ax2)
                        v3 = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + ax3)
                        T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
                        T.writes(conv2d_nhwc[v0, v1, v2, v3])
                        conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
    @T.prim_func
    def dil_1(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 3, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 109, 109, 64), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64})
            PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
            conv2d_nhwc_global = T.alloc_buffer([1, 109, 109, 64], dtype="float32")
            for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 109, 1, 4):
                for i0_1, i1_1, i2_1, i3_1, i4_0 in T.grid(1, 1, 1, 2, 7):
                    for ax0, ax1, ax2, ax3 in T.grid(1, 1, 229, 3):
                        with T.block("PadInput"):
                            i0 = T.axis.spatial(1, ax0)
                            i1 = T.axis.spatial(230, i1_0 * 2 + i4_0 * 2 + ax1)
                            i2 = T.axis.spatial(230, ax2)
                            i3 = T.axis.spatial(3, ax3)
                            T.reads(inputs[i0, i1 - 3, i2 - 3, i3])
                            T.writes(PadInput[i0, i1, i2, i3])
                            PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
                    for i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
                        with T.block("conv2d_nhwc"):
                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
                            h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
                            w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
                            co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
                            rh = T.axis.reduce(7, i4_0 + i4_1)
                            rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
                            rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                            T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
                            T.writes(conv2d_nhwc_global[n, h, w, co])
                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                            with T.init():
                                conv2d_nhwc_global[n, h, w, co] = T.float32(0)
                            conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight[rh, rw, rc, co]
                for ax0, ax1, ax2, ax3 in T.grid(1, 1, 109, 16):
                    with T.block("conv2d_nhwc_global"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(109, i1_0 + ax1)
                        v2 = T.axis.spatial(109, ax2)
                        v3 = T.axis.spatial(64, i3_0 * 16 + ax3)
                        T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
                        T.writes(conv2d_nhwc[v0, v1, v2, v3])
                        conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
    @T.prim_func
    def dil_2(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 3, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 109, 109, 64), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64})
            PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
            for i0_0, i1_0 in T.grid(1, 109):
                for ax0, ax1, ax2, ax3 in T.grid(1, 13, 229, 3):
                    with T.block("PadInput"):
                        i0 = T.axis.spatial(1, ax0)
                        i1 = T.axis.spatial(230, i1_0 * 2 + ax1)
                        i2 = T.axis.spatial(230, ax2)
                        i3 = T.axis.spatial(3, ax3)
                        T.reads(inputs[i0, i1 - 3, i2 - 3, i3])
                        T.writes(PadInput[i0, i1, i2, i3])
                        PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
                for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 1, 1, 2, 7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
                    with T.block("conv2d_nhwc"):
                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
                        h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
                        w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
                        co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
                        rh = T.axis.reduce(7, i4_0 + i4_1)
                        rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
                        rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                        T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
                        T.writes(conv2d_nhwc[n, h, w, co])
                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                        with T.init():
                            conv2d_nhwc[n, h, w, co] = T.float32(0)
                        conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight[rh, rw, rc, co]

    # fmt: on
    # Sampled tiling/annotation decisions that reproduce each sketch above.
    decision_0 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [109, 1, 1, 1]),
        ("SamplePerfectTile", [1, 1, 109, 1]),
        ("SamplePerfectTile", [4, 2, 8, 1]),
        ("SamplePerfectTile", [7, 1]),
        ("SamplePerfectTile", [1, 7]),
        ("SamplePerfectTile", [1, 3]),
        ("SampleCategorical", 2),
        ("SampleComputeLocation", 7),
    ]
    decision_1 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [109, 1, 1, 1]),
        ("SamplePerfectTile", [1, 1, 109, 1]),
        ("SamplePerfectTile", [4, 2, 8, 1]),
        ("SamplePerfectTile", [7, 1]),
        ("SamplePerfectTile", [1, 7]),
        ("SamplePerfectTile", [1, 3]),
        ("SampleCategorical", 0),
        ("SampleComputeLocation", 8),
    ]
    decision_2 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [109, 1, 1, 1]),
        ("SamplePerfectTile", [1, 1, 109, 1]),
        ("SamplePerfectTile", [4, 2, 8, 1]),
        ("SamplePerfectTile", [7, 1]),
        ("SamplePerfectTile", [1, 7]),
        ("SamplePerfectTile", [1, 3]),
        ("SampleCategorical", 0),
        ("SampleComputeLocation", 1),
    ]
    # Generate the design space for the DIL workload and compare it against
    # the expected sketches and sampling decisions.
    mod = create_te_workload("DIL", 0)
    actual = ms.TuneContext(
        mod=mod,
        target=_target(),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules="default",
    ).generate_design_space()
    check_sketches(
        mod,
        sketches=actual,
        expected_mods=[dil_0, dil_1, dil_2],
        expected_decisions=[decision_0, decision_1, decision_2],
    )


if __name__ == "__main__":
    # Run every CPU design-space test in this file when invoked as a script.
    test_cpu_c1d()
    test_cpu_c2d()
    test_cpu_c3d()
    test_cpu_cap()
    test_cpu_dep()
    test_cpu_dil()
89 changes: 89 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,98 @@ def dep_0(placeholder: T.Buffer[(1, 112, 112, 32), "float32"], placeholder_1: T.
)


def test_cuda_dil():
    """Check the CUDA design space generated for the DIL (dilated conv2d) workload.

    A single sketch is expected: the SSSRRSRS tiling with cooperative fetching
    of the padded input and the weights into shared memory, accumulation in a
    local cache, and a final write-back to the output buffer.
    """
    # fmt: off
    @T.prim_func
    def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 3, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 109, 109, 64), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.unroll_explicit":512})
            conv2d_nhwc_local = T.alloc_buffer([1, 109, 109, 64], dtype="float32", scope="local")
            PadInput_shared = T.alloc_buffer([1, 230, 230, 3], dtype="float32", scope="shared")
            weight_shared = T.alloc_buffer([7, 7, 3, 64], dtype="float32", scope="shared")
            for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(218, thread="blockIdx.x"):
                for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(109, thread="vthread.x"):
                    for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(1, thread="threadIdx.x"):
                        for i4_0, i5_0, i6_0 in T.grid(7, 7, 3):
                            for ax0_ax1_ax2_ax3_fused in T.serial(217):
                                with T.block("PadInput_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2 + 0)
                                    v2 = T.axis.spatial(230, i5_0 * 2 + ax0_ax1_ax2_ax3_fused % 217)
                                    v3 = T.axis.spatial(3, i6_0 + 0)
                                    T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
                                    T.writes(PadInput_shared[v0, v1, v2, v3])
                                    T.block_attr({"meta_schedule.cooperative_fetch":2})
                                    PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0), dtype="float32")
                            for ax0_ax1_ax2_ax3_fused in T.serial(32):
                                with T.block("weight_shared"):
                                    v0, v1, v2 = T.axis.remap("SSS", [i4_0, i5_0, i6_0])
                                    v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused)
                                    T.reads(weight[v0, v1, v2, v3])
                                    T.writes(weight_shared[v0, v1, v2, v3])
                                    T.block_attr({"meta_schedule.cooperative_fetch":4})
                                    weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
                            for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 4):
                                with T.block("conv2d_nhwc"):
                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
                                    h = T.axis.spatial(109, i0_0_i1_0_i2_0_i3_0_fused % 218 // 2 + 0 + 0 + i1_3 + i1_4)
                                    w = T.axis.spatial(109, 0 * 109 + i0_1_i1_1_i2_1_i3_1_fused % 109 + 0 + i2_3 + i2_4)
                                    co = T.axis.spatial(64, ((i0_0_i1_0_i2_0_i3_0_fused % 2 + 0 + 0) * 8 + i3_3) * 4 + i3_4)
                                    rh = T.axis.reduce(7, i4_0 + i4_1 + i4_2)
                                    rw = T.axis.reduce(7, i5_0 + i5_1 + i5_2)
                                    rc = T.axis.reduce(3, i6_0 + i6_1 + i6_2)
                                    T.reads(PadInput_shared[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight_shared[rh, rw, rc, co])
                                    T.writes(conv2d_nhwc_local[n, h, w, co])
                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
                                    with T.init():
                                        conv2d_nhwc_local[n, h, w, co] = T.float32(0)
                                    conv2d_nhwc_local[n, h, w, co] = conv2d_nhwc_local[n, h, w, co] + PadInput_shared[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc] * weight_shared[rh, rw, rc, co]
                        for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 32):
                            with T.block("conv2d_nhwc_local"):
                                v0 = T.axis.spatial(1, ax0)
                                v1 = T.axis.spatial(109, i0_0_i1_0_i2_0_i3_0_fused // 2 + ax1)
                                v2 = T.axis.spatial(109, i0_1_i1_1_i2_1_i3_1_fused + ax2)
                                v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3)
                                T.reads(conv2d_nhwc_local[v0, v1, v2, v3])
                                T.writes(conv2d_nhwc[v0, v1, v2, v3])
                                conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3]
    # fmt: on
    # Sampled tiling/annotation decisions that reproduce the sketch above.
    decision_0 = [
        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
        ("SamplePerfectTile", [109, 1, 1, 1, 1]),
        ("SamplePerfectTile", [1, 109, 1, 1, 1]),
        ("SamplePerfectTile", [2, 1, 1, 8, 4]),
        ("SamplePerfectTile", [7, 1, 1]),
        ("SamplePerfectTile", [7, 1, 1]),
        ("SamplePerfectTile", [3, 1, 1]),
        ("SampleCategorical", 1),
        ("SampleCategorical", 3),
        ("SampleCategorical", 3),
    ]
    # Generate the design space for the DIL workload and compare it against
    # the expected sketch and sampling decisions.
    mod = create_te_workload("DIL", 0)
    actual = ms.TuneContext(
        mod=mod,
        target=_target(),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules="default",
    ).generate_design_space()
    check_sketches(
        mod,
        sketches=actual,
        expected_mods=[dil_0],
        expected_decisions=[decision_0],
    )


if __name__ == "__main__":
    # Run every CUDA design-space test in this file when invoked as a script.
    test_cuda_c1d()
    test_cuda_c2d()
    test_cuda_c3d()
    test_cuda_cap()
    test_cuda_dep()
    test_cuda_dil()