diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index f6417acbe613..d08a88201d2e 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -28,7 +28,13 @@ from tvm import relay from tvm.relay.op.contrib import cmsisnn -from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str +from utils import ( + skip_if_no_reference_system, + make_module, + get_range_for_dtype_str, + assert_partitioned_function, + assert_no_external_function, +) from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, @@ -113,21 +119,7 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsisnn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output in_min, in_max = get_range_for_dtype_str(dtype) @@ -204,21 +196,7 @@ def test_constant_input_int8(op, input_0, input_1): cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsisnn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output in_min, in_max = get_range_for_dtype_str(dtype) @@ -262,13 +240,7 @@ def test_both_scalar_inputs_int8( orig_mod = make_module(model) cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) - - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert not any(attrs), "No function should have an external attribute." + assert_no_external_function(cmsisnn_mod) @skip_if_no_reference_system @@ -293,13 +265,7 @@ def test_invalid_parameters( orig_mod = make_module(model) cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) - - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert not any(attrs), "No function should have an external attribute." + assert_no_external_function(cmsisnn_mod) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index e9eb6fb3a145..ddcb2c10bff7 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -34,11 +34,12 @@ from utils import ( skip_if_no_reference_system, make_module, - count_num_calls, get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, make_qnn_relu, + assert_partitioned_function, + assert_no_external_function, ) @@ -192,21 +193,7 @@ def test_conv2d_symmetric_padding_int8( cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsis-nn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output rng = np.random.default_rng(12345) @@ -295,21 +282,7 @@ def test_conv2d_asymmetric_padding_int8( cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsis-nn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output rng = np.random.default_rng(12345) @@ -413,21 +386,7 @@ def test_depthwise_int8( cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsis-nn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output rng = np.random.default_rng(12345) @@ -513,14 +472,7 @@ def test_invalid_parameters( ) orig_mod = make_module(model) cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) - - # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert not any(attrs), "No function should have an external attribute." + assert_no_external_function(cmsisnn_mod) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cmsisnn/test_fully_connected.py b/tests/python/contrib/test_cmsisnn/test_fully_connected.py index 42b36a77b77f..9d4b1e155122 100644 --- a/tests/python/contrib/test_cmsisnn/test_fully_connected.py +++ b/tests/python/contrib/test_cmsisnn/test_fully_connected.py @@ -34,11 +34,12 @@ from utils import ( skip_if_no_reference_system, make_module, - count_num_calls, get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, make_qnn_relu, + assert_partitioned_function, + assert_no_external_function, ) @@ -152,21 +153,7 @@ def test_op_int8( cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsisnn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output rng = np.random.default_rng(12345) @@ -247,12 +234,7 @@ def test_invalid_parameters( cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert not any(attrs), "No function should have an external attribute." + assert_no_external_function(cmsisnn_mod) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cmsisnn/test_generate_constants.py b/tests/python/contrib/test_cmsisnn/test_generate_constants.py index c5e97253d94b..1f6c76381580 100644 --- a/tests/python/contrib/test_cmsisnn/test_generate_constants.py +++ b/tests/python/contrib/test_cmsisnn/test_generate_constants.py @@ -26,7 +26,6 @@ from utils import ( make_module, - count_num_calls, get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py index 1c440b1e1de4..ee4f5c4aea4d 100644 --- a/tests/python/contrib/test_cmsisnn/test_pooling.py +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -34,11 +34,12 @@ from utils import ( skip_if_no_reference_system, make_module, - count_num_calls, get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, make_qnn_relu, + assert_partitioned_function, + assert_no_external_function, ) @@ -106,21 +107,7 @@ def test_op_int8( cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsisnn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output in_min, in_max = get_range_for_dtype_str(dtype) @@ -159,14 +146,7 @@ def test_invalid_parameters(): orig_mod = make_module(model) cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) - - # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert not any(attrs), "No function should have an external attribute." + assert_no_external_function(cmsisnn_mod) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index 5d1c2fdcc8c1..c3617cce15d4 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -30,8 +30,9 @@ from utils import ( skip_if_no_reference_system, make_module, - count_num_calls, get_range_for_dtype_str, + assert_partitioned_function, + assert_no_external_function, ) from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, @@ -77,21 +78,7 @@ def test_op_int8(zero_point, scale): cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) # validate pattern matching - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert any(attrs), "At least one function with external attributes was expected." - - compilers = [ - key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() - ] - assert any(compilers), "Module does not contain function for cmsisnn target." - - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output in_min, in_max = get_range_for_dtype_str(dtype) @@ -142,13 +129,7 @@ def test_invalid_parameters(in_dtype, out_dtype, zero_point, scale, out_zero_poi orig_mod = make_module(model) cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) - - attrs = [ - cmsisnn_mod[var.name_hint].attrs - for var in cmsisnn_mod.get_global_vars() - if cmsisnn_mod[var.name_hint].attrs - ] - assert not any(attrs), "No function should have an external attribute." + assert_no_external_function(cmsisnn_mod) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index 3575284eb1c0..e23c0a66bafb 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -54,6 +54,29 @@ def visit_call(self, call): return counter.count +def assert_partitioned_function(orig_mod, cmsisnn_mod): + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsisnn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + +def assert_no_external_function(mod): + attrs = [mod[var.name_hint].attrs for var in mod.get_global_vars() if mod[var.name_hint].attrs] + assert not any(attrs), "No function should have an external attribute." + + def get_range_for_dtype_str(dtype): """ Produces the min,max for a give data type.