tests/python/relay/test_op_qnn_conv2d.py: 83 changes (59 additions, 24 deletions)
@@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
qnn_output = get_output(qnn_func, golden_inputs)
np.testing.assert_equal(qnn_output, golden_output)

-def no_zero_point_test():
+def test_no_zero_point():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

-def kernel_zero_point_test():
+def test_kernel_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def kernel_zero_point_test():
kernel_shape, kernel_dtype)


-def input_zero_point_test():
+def test_input_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

-def both_zero_point_test():
+def test_both_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def both_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

-def layout_test():
+def test_layout():
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def layout_test():



-def padding_test():
+def test_padding():
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def padding_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

-def dilation_test():
+def test_dilation():
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
kernel_shape, kernel_dtype)


-def const_folding_test():
+def test_const_folding():
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
@@ -470,7 +470,7 @@ def const_folding_test():
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()

-def kernel_size_1x1_test():
+def test_kernel_size_1x1():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

-def tflite_large_irregular_test():
+def test_tflite_large_irregular():
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)

-def broadcast_layout_test():
+def test_broadcast_layout():
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
@@ -640,17 +640,52 @@ def broadcast_layout_test():
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")


+def test_conv2d_int8():
+    target = "llvm -mcpu=core-avx2"
+    if not tvm.module.enabled(target):
+        print("skip because %s is not enabled..." % target)
+        return
+
+    data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8')
+    kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8')
+    conv = relay.nn.conv2d(
+        data,
+        kernel,
+        kernel_size=(3, 3),
+        out_dtype='int32',
+        data_layout='NHWC',
+        kernel_layout='HWIO')
+    func = relay.Function([data, kernel], conv)
+
+    with relay.build_config(opt_level=0):
+        params = {"w": np.zeros((3, 3, 128, 256)).astype("int8")}
+        # -mcpu should be specified to avoid the llvm jitting error here:
+        # https://discuss.tvm.ai/t/segfault-in-llvm/3567
+        # To use VNNI, we need to specify the micro-architecture that supports
+        # it, e.g. cascadelake.
+        graph, lib, params = relay.build(func, target, params=params)
+    mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+    mod.set_input("data", np.zeros((1, 28, 28, 128)).astype("uint8"))
+    mod.set_input(**params)
+    mod.run()
+    qnn_output = mod.get_output(0).asnumpy()
+    golden_output = np.zeros((1, 26, 26, 256)).astype("int32")
+    np.testing.assert_equal(qnn_output, golden_output)
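
The in-code comment above hints at VNNI. As a minimal sketch (an assumption on my part, not part of this change), retargeting the same build at a VNNI-capable micro-architecture only means swapping the -mcpu value:

    # Hypothetical retarget: Cascade Lake is the Intel micro-architecture
    # that introduced AVX-512 VNNI.
    target = "llvm -mcpu=cascadelake"
    if tvm.module.enabled(target):  # same guard as the test above
        graph, lib, params = relay.build(func, target, params=params)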


if __name__ == "__main__":
-    no_zero_point_test()
-    input_zero_point_test()
-    kernel_zero_point_test()
-    both_zero_point_test()
-    layout_test()
-    padding_test()
-    dilation_test()
-    const_folding_test()
-    kernel_size_1x1_test()
-    tflite_large_irregular_test()
-    tflite_output_multiplier_greater_than_one()
-    tflite_anistropic_strides()
-    broadcast_layout_test()
+    test_no_zero_point()
+    test_input_zero_point()
+    test_kernel_zero_point()
+    test_both_zero_point()
+    test_layout()
+    test_padding()
+    test_dilation()
+    test_const_folding()
+    test_kernel_size_1x1()
+    test_tflite_large_irregular()
+    test_tflite_output_multiplier_greater_than_one()
+    test_tflite_anistropic_strides()
+    test_broadcast_layout()
+    test_conv2d_int8()
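
With the test_ prefix, pytest's default discovery now picks these functions up directly (e.g. pytest tests/python/relay/test_op_qnn_conv2d.py), so the __main__ block above remains only as a convenience for running the file as a script.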