Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 11 additions & 45 deletions tests/python/contrib/test_cmsisnn/test_binary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str
from utils import (
skip_if_no_reference_system,
make_module,
get_range_for_dtype_str,
assert_partitioned_function,
assert_no_external_function,
)
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
Expand Down Expand Up @@ -113,21 +119,7 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
Expand Down Expand Up @@ -204,21 +196,7 @@ def test_constant_input_int8(op, input_0, input_1):
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
Expand Down Expand Up @@ -262,13 +240,7 @@ def test_both_scalar_inputs_int8(

orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."
assert_no_external_function(cmsisnn_mod)


@skip_if_no_reference_system
Expand All @@ -293,13 +265,7 @@ def test_invalid_parameters(

orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."
assert_no_external_function(cmsisnn_mod)


if __name__ == "__main__":
Expand Down
60 changes: 6 additions & 54 deletions tests/python/contrib/test_cmsisnn/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@
from utils import (
skip_if_no_reference_system,
make_module,
count_num_calls,
get_range_for_dtype_str,
get_same_padding,
get_conv2d_qnn_params,
make_qnn_relu,
assert_partitioned_function,
assert_no_external_function,
)


Expand Down Expand Up @@ -192,21 +193,7 @@ def test_conv2d_symmetric_padding_int8(
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsis-nn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
rng = np.random.default_rng(12345)
Expand Down Expand Up @@ -295,21 +282,7 @@ def test_conv2d_asymmetric_padding_int8(
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsis-nn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
rng = np.random.default_rng(12345)
Expand Down Expand Up @@ -413,21 +386,7 @@ def test_depthwise_int8(
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsis-nn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
rng = np.random.default_rng(12345)
Expand Down Expand Up @@ -513,14 +472,7 @@ def test_invalid_parameters(
)
orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."
assert_no_external_function(cmsisnn_mod)


if __name__ == "__main__":
Expand Down
26 changes: 4 additions & 22 deletions tests/python/contrib/test_cmsisnn/test_fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@
from utils import (
skip_if_no_reference_system,
make_module,
count_num_calls,
get_range_for_dtype_str,
get_same_padding,
get_conv2d_qnn_params,
make_qnn_relu,
assert_partitioned_function,
assert_no_external_function,
)


Expand Down Expand Up @@ -152,21 +153,7 @@ def test_op_int8(
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
rng = np.random.default_rng(12345)
Expand Down Expand Up @@ -247,12 +234,7 @@ def test_invalid_parameters(
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."
assert_no_external_function(cmsisnn_mod)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from utils import (
make_module,
count_num_calls,
get_range_for_dtype_str,
get_same_padding,
get_conv2d_qnn_params,
Expand Down
28 changes: 4 additions & 24 deletions tests/python/contrib/test_cmsisnn/test_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@
from utils import (
skip_if_no_reference_system,
make_module,
count_num_calls,
get_range_for_dtype_str,
get_same_padding,
get_conv2d_qnn_params,
make_qnn_relu,
assert_partitioned_function,
assert_no_external_function,
)


Expand Down Expand Up @@ -106,21 +107,7 @@ def test_op_int8(
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
Expand Down Expand Up @@ -159,14 +146,7 @@ def test_invalid_parameters():

orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."
assert_no_external_function(cmsisnn_mod)


if __name__ == "__main__":
Expand Down
27 changes: 4 additions & 23 deletions tests/python/contrib/test_cmsisnn/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
from utils import (
skip_if_no_reference_system,
make_module,
count_num_calls,
get_range_for_dtype_str,
assert_partitioned_function,
assert_no_external_function,
)
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
Expand Down Expand Up @@ -77,21 +78,7 @@ def test_op_int8(zero_point, scale):
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
Expand Down Expand Up @@ -142,13 +129,7 @@ def test_invalid_parameters(in_dtype, out_dtype, zero_point, scale, out_zero_poi

orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."
assert_no_external_function(cmsisnn_mod)


if __name__ == "__main__":
Expand Down
23 changes: 23 additions & 0 deletions tests/python/contrib/test_cmsisnn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ def visit_call(self, call):
return counter.count


def assert_partitioned_function(orig_mod, cmsisnn_mod):
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"


def assert_no_external_function(mod):
attrs = [mod[var.name_hint].attrs for var in mod.get_global_vars() if mod[var.name_hint].attrs]
assert not any(attrs), "No function should have an external attribute."


def get_range_for_dtype_str(dtype):
"""
Produces the min,max for a give data type.
Expand Down