Commit 6f9d028

zhiics authored and icemelon committed
[Relay][QNN] Add unit test for int8 (#4159)
* [bugfix][codegen] fix casting bug in llvm codegen
* update example
* retrigger ci
* check llvm version
1 parent e0d286a commit 6f9d028

1 file changed: +59 −24 lines

tests/python/relay/test_op_qnn_conv2d.py
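
The bulk of this diff renames test functions from the `*_test` suffix style to the `test_*` prefix style, which is what pytest's default collection pattern picks up; the old names only ran when the file was executed directly. A minimal sketch of invoking the renamed tests programmatically (standard pytest API, not part of this commit):

import pytest

# Collect and run the test_* functions in this file; functions named
# *_test would not be discovered by pytest's default pattern.
pytest.main(["tests/python/relay/test_op_qnn_conv2d.py", "-v"])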

@@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
     qnn_output = get_output(qnn_func, golden_inputs)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def no_zero_point_test():
+def test_no_zero_point():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def kernel_zero_point_test():
+def test_kernel_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def kernel_zero_point_test():
            kernel_shape, kernel_dtype)
 
 
-def input_zero_point_test():
+def test_input_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def both_zero_point_test():
+def test_both_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def both_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def layout_test():
+def test_layout():
     # uint8 input
     data_shape = (2, 2, 4, 4) # NHWC
     data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def layout_test():
 
 
 
-def padding_test():
+def test_padding():
     # uint8 input
     data_shape = (1, 4, 2, 2)
     data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def padding_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def dilation_test():
+def test_dilation():
     # uint8 input
     data_shape = (2, 4, 4, 4)
     data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
            kernel_shape, kernel_dtype)
 
 
-def const_folding_test():
+def test_const_folding():
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
     kernel_shape = (3, 4, 2, 2)
@@ -470,7 +470,7 @@ def const_folding_test():
     folded_func = folded_mod["main"]
     assert "reshape" not in folded_func.astext()
 
-def kernel_size_1x1_test():
+def test_kernel_size_1x1():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def tflite_large_irregular_test():
+def test_tflite_large_irregular():
     # uint8 input
     data_shape = (1, 1024, 1, 1)
     data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
     golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def broadcast_layout_test():
+def test_broadcast_layout():
     # Test broadcast support for NHWC layout.
     data_shape = (1, 229, 229, 3) # NHWC
     data_dtype = 'uint8'
@@ -640,17 +640,52 @@ def broadcast_layout_test():
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
 
+
+def test_conv2d_int8():
+    target = "llvm -mcpu=core-avx2"
+    if not tvm.module.enabled(target):
+        print("skip because %s is not enabled..." % target)
+        return
+
+    data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8')
+    kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8')
+    conv = relay.nn.conv2d(
+        data,
+        kernel,
+        kernel_size=(3, 3),
+        out_dtype='int32',
+        data_layout='NHWC',
+        kernel_layout='HWIO')
+    func = relay.Function([data, kernel], conv)
+
+    with relay.build_config(opt_level=0):
+        params = {"w": np.zeros((3, 3, 128, 256)).astype("int8")}
+        # -mcpu should be specified to avoid the llvm jitting error here:
+        # https://discuss.tvm.ai/t/segfault-in-llvm/3567
+        # To use VNNI, we need to specify the micro-architecture that supports
+        # it, e.g. cascadelake.
+        graph, lib, params = relay.build(func, target, params=params)
+    mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+    mod.set_input("data", np.zeros((1, 28, 28, 128)).astype("uint8"))
+    mod.set_input(**params)
+    mod.run()
+    qnn_output = mod.get_output(0).asnumpy()
+    golden_output = np.zeros((1, 26, 26, 256)).astype("int32")
+    np.testing.assert_equal(qnn_output, golden_output)
+
+
 if __name__ == "__main__":
-    no_zero_point_test()
-    input_zero_point_test()
-    kernel_zero_point_test()
-    both_zero_point_test()
-    layout_test()
-    padding_test()
-    dilation_test()
-    const_folding_test()
-    kernel_size_1x1_test()
-    tflite_large_irregular_test()
-    tflite_output_multiplier_greater_than_one()
-    tflite_anistropic_strides()
-    broadcast_layout_test()
+    test_no_zero_point()
+    test_input_zero_point()
+    test_kernel_zero_point()
+    test_both_zero_point()
+    test_layout()
+    test_padding()
+    test_dilation()
+    test_const_folding()
+    test_kernel_size_1x1()
+    test_tflite_large_irregular()
+    test_tflite_output_multiplier_greater_than_one()
+    test_tflite_anistropic_strides()
+    test_broadcast_layout()
+    test_conv2d_int8()
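
The comment inside the new test points at targeting a VNNI-capable micro-architecture instead of core-avx2. A minimal sketch of what that would look like, reusing the same 0.6-era TVM API as the test above and assuming `cascadelake` as the VNNI-capable `-mcpu` value; this is illustrative only and not part of the commit:

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Assumed VNNI-capable target string (see the in-test comment); guard the
# build the same way the test guards its core-avx2 target.
target = "llvm -mcpu=cascadelake"
if tvm.module.enabled(target):
    data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8')
    kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8')
    conv = relay.nn.conv2d(data, kernel, kernel_size=(3, 3),
                           out_dtype='int32', data_layout='NHWC',
                           kernel_layout='HWIO')
    func = relay.Function([data, kernel], conv)
    with relay.build_config(opt_level=0):
        # Same zero-filled weights as the test; only the target differs.
        graph, lib, params = relay.build(
            func, target, params={"w": np.zeros((3, 3, 128, 256), "int8")})
    mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
    mod.set_input("data", np.zeros((1, 28, 28, 128), "uint8"))
    mod.set_input(**params)
    mod.run()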
