Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 24 additions & 59 deletions tests/python/relay/test_op_qnn_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
qnn_output = get_output(qnn_func, golden_inputs)
np.testing.assert_equal(qnn_output, golden_output)

def test_no_zero_point():
def no_zero_point_test():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -203,7 +203,7 @@ def test_no_zero_point():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def test_kernel_zero_point():
def kernel_zero_point_test():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -247,7 +247,7 @@ def test_kernel_zero_point():
kernel_shape, kernel_dtype)


def test_input_zero_point():
def input_zero_point_test():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -290,7 +290,7 @@ def test_input_zero_point():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def test_both_zero_point():
def both_zero_point_test():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -333,7 +333,7 @@ def test_both_zero_point():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def test_layout():
def layout_test():
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
Expand Down Expand Up @@ -378,7 +378,7 @@ def test_layout():



def test_padding():
def padding_test():
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
Expand Down Expand Up @@ -421,7 +421,7 @@ def test_padding():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def test_dilation():
def dilation_test():
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
Expand All @@ -444,7 +444,7 @@ def test_dilation():
kernel_shape, kernel_dtype)


def test_const_folding():
def const_folding_test():
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
Expand All @@ -470,7 +470,7 @@ def test_const_folding():
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()

def test_kernel_size_1x1():
def kernel_size_1x1_test():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand All @@ -493,7 +493,7 @@ def test_kernel_size_1x1():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def test_tflite_large_irregular():
def tflite_large_irregular_test():
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
Expand Down Expand Up @@ -607,7 +607,7 @@ def tflite_anistropic_strides():
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)

def test_broadcast_layout():
def broadcast_layout_test():
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
Expand Down Expand Up @@ -640,52 +640,17 @@ def test_broadcast_layout():
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")


def test_conv2d_int8():
target = "llvm -mcpu=core-avx2"
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return

data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8')
kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8')
conv = relay.nn.conv2d(
data,
kernel,
kernel_size=(3, 3),
out_dtype='int32',
data_layout='NHWC',
kernel_layout='HWIO')
func = relay.Function([data, kernel], conv)

with relay.build_config(opt_level=0):
params = {"w": np.zeros((3, 3, 128, 256)).astype("int8")}
# -mcpu should be specified to avoid the llvm jitting error here:
# https://discuss.tvm.ai/t/segfault-in-llvm/3567
# To use VNNI, we need to specify the micro-architecture that supports
# it, e.g. cascadelake.
graph, lib, params = relay.build(func, target, params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", np.zeros((1, 28, 28, 128)).astype("uint8"))
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.zeros((1, 26, 26, 256)).astype("int32")
np.testing.assert_equal(qnn_output, golden_output)


if __name__ == "__main__":
test_no_zero_point()
test_input_zero_point()
test_kernel_zero_point()
test_both_zero_point()
test_layout()
test_padding()
test_dilation()
test_const_folding()
test_kernel_size_1x1g()
test_tflite_large_irregularg()
test_tflite_output_multiplier_greater_than_one()
test_tflite_anistropic_strides()
test_broadcast_layoutg()
test_conv2d_int8()
no_zero_point_test()
input_zero_point_test()
kernel_zero_point_test()
both_zero_point_test()
layout_test()
padding_test()
dilation_test()
const_folding_test()
kernel_size_1x1_test()
tflite_large_irregular_test()
tflite_output_multiplier_greater_than_one()
tflite_anistropic_strides()
broadcast_layout_test()