-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Closed
Labels
Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
When using MetaSchedule to tune a conv3d NCDHW workload, the tuning result cannot pass the well-formedness check, which causes the following error:
ValueError: Invalid use of undefined variable C_s0 at <root>.body.block.body.body.body.seq[0].body.seq[0].body.body.body.body.body.body.body.body.block.match_buffers[0].buffer.strides[0].
To reproduce, first run the tuning with the following code with `tune` set to True, and then apply the database (with `tune` set to False) to reproduce the error. To directly get the tuned workload TIR and run the test, you can also try this script.
Thanks to @jwfromm for reporting this issue.
import tvm
from tvm.script import tir as T, ir as I
from tvm import meta_schedule as ms
from tvm.tir.tensor_intrin import *
# TVMScript TIR workload that triggers the MetaSchedule well-formedness bug.
# Pipeline of four stages over (2, 1280, 25, 18, 32) float16 tensors:
#   1. zero-pad the depth axis of `silu33` by 1 on each side (25 -> 27),
#   2. conv3d in NCDHW layout with a (3, 1, 1) kernel (reduction over
#      channel + kernel-depth axes),
#   3. add the per-channel bias `lv113` (broadcast over N, D, H, W),
#   4. add the residual `permute_dims152` (broadcast over H, W).
@T.prim_func(private=True)
def func(silu33: T.Buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16"), down_blocks_2_resnets_0_temporal_res_block_conv1_weight: T.Buffer((T.int64(1280), T.int64(1280), T.int64(3), T.int64(1), T.int64(1)), "float16"), lv113: T.Buffer((T.int64(1), T.int64(1280), T.int64(1), T.int64(1), T.int64(1)), "float16"), permute_dims152: T.Buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(1), T.int64(1)), "float16"), T_add_intermediate_1: T.Buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Intermediate buffers: padded input and the two elementwise stages.
    pad_temp = T.alloc_buffer((T.int64(2), T.int64(1280), T.int64(27), T.int64(18), T.int64(32)), "float16")
    conv3d_ncdhw_intermediate = T.alloc_buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16")
    T_add_intermediate = T.alloc_buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16")
    # Stage 1: pad depth (axis 2) with one zero plane on each side.
    for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(1280), T.int64(27), T.int64(18), T.int64(32)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(silu33[v_i0, v_i1, v_i2 - T.int64(1), v_i3, v_i4])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3, v_i4])
            # Inside [1, 26): copy shifted input; outside: zero padding.
            pad_temp[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(26), silu33[v_i0, v_i1, v_i2 - T.int64(1), v_i3, v_i4], T.float16(0))
    # Stage 2: conv3d NCDHW; rc/ry/rx/rz are reduction ("R") axes.
    for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32), T.int64(1280), T.int64(3), T.int64(1), T.int64(1)):
        with T.block("conv3d_ncdhw"):
            v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx, v_zz + v_rz], down_blocks_2_resnets_0_temporal_res_block_conv1_weight[v_ff, v_rc, v_ry, v_rx, v_rz])
            T.writes(conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz])
            with T.init():
                # Zero-initialize the accumulator before the reduction.
                conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0)
            conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz] = conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx, v_zz + v_rz] * down_blocks_2_resnets_0_temporal_res_block_conv1_weight[v_ff, v_rc, v_ry, v_rx, v_rz]
    # Stage 3: add per-channel bias lv113 (indexed only by the channel axis).
    for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
            T.reads(conv3d_ncdhw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], lv113[T.int64(0), v_ax1, T.int64(0), T.int64(0), T.int64(0)])
            T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = conv3d_ncdhw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + lv113[T.int64(0), v_ax1, T.int64(0), T.int64(0), T.int64(0)]
    # Stage 4: add residual permute_dims152 (broadcast over H and W) into the output.
    for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)):
        with T.block("T_add_1"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
            T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], permute_dims152[v_ax0, v_ax1, v_ax2, T.int64(0), T.int64(0)])
            T.writes(T_add_intermediate_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add_intermediate_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + permute_dims152[v_ax0, v_ax1, v_ax2, T.int64(0), T.int64(0)]
if __name__ == "__main__":
    # Print the untuned TIR for reference.
    func.show()
    target = tvm.target.Target("nvidia/nvidia-a10g")
    # Set to True on the first run to produce the tuning database under
    # ./temp; set to False afterwards to replay the stored schedule and
    # reproduce the well-formedness failure.
    tune = False
    if tune:
        db = ms.tune_tir(func, target=target, work_dir="./temp", max_trials_global=500)
    else:
        db = ms.database.JSONDatabase(work_dir="./temp")
    # Query the database for the tuned schedule of this workload.
    mod = tvm.ir.IRModule({"main": func.with_attrs({"global_symbol": "main"})})
    tuned_mod = db.query_ir_module(mod=mod, target=target, workload_name="main")
    tuned_mod.show()
    tvm.build(tuned_mod, target=target)
    # Fix: the page extraction had fused the stray word "Metadata" onto the
    # end of this call, making the line a syntax error.
    tvm.tir.analysis.verify_well_formed(tuned_mod)
Metadata
Assignees
Labels
Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug