diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index fcdd948260ea..25e8ffe94155 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -195,11 +195,13 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo cfg.define_split("tile_ic", in_channel, num_outputs=2) cfg.define_split("tile_oc", num_filter, num_outputs=2) - cfg.define_split( - "tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64, policy="verbose" - ) + if isinstance(ow, (tvm.tir.IntImm, int)): + cfg.define_split( + "tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64, policy="verbose" + ) if is_kernel_1x1: - cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1]) + if isinstance(oh, (tvm.tir.IntImm, int)): + cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1]) else: cfg.define_knob("unroll_kw", [True, False]) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 86af1ad1c38b..443637276e24 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -548,6 +548,8 @@ def verify_any_conv2d( data_layout="NCHW", kernel_layout="OIHW", use_cudnn=False, + targets=None, + disable_targets=None, ): mod = tvm.IRModule() dtype = "float32" @@ -567,11 +569,17 @@ def verify_any_conv2d( data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) - targets = None if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True): targets = [("cuda -libs=cudnn", tvm.cuda(0))] - check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets) + check_result( + [data_np, kernel_np], + mod, + ref_out_shape, + assert_shape=True, + targets=targets, + disable_targets=disable_targets, + ) # TODO(@kevinthesun): Support dynamic input height and width. @@ -627,6 +635,26 @@ def test_any_conv2d(): data_layout="NHWC", kernel_layout="HWIO", ) + verify_any_conv2d( + (relay.Any(), 64, relay.Any(), relay.Any()), + (64, 64, 3, 3), + (1, 1), + (1, 1), + (1, 1), + (1, 64, 224, 224), + (1, 64, 224, 224), + targets=[("llvm", tvm.cpu(0))], + ) + verify_any_conv2d( + (relay.Any(), 64, relay.Any(), relay.Any()), + (64, 64, 1, 1), + (1, 1), + (0, 0), + (1, 1), + (1, 64, 224, 224), + (1, 64, 224, 224), + targets=[("llvm", tvm.cpu(0))], + ) class TestAnyConv2dNCHWc: