Skip to content

Commit bedc772

Browse files
authored
[microNPU] Add support for SPLIT and SPLIT_V (#9621)
Both SPLIT and SPLIT_V are lowered to relay.split, and during legalization the Relay split is turned into strided slices. This patch adds the pattern and the legalizer needed to offload TFLite's splits to the NPU.
1 parent 510f7c6 commit bedc772

File tree

4 files changed

+243
-0
lines changed

4 files changed

+243
-0
lines changed

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,25 @@ def callback(
108108
return relay.Tuple(strided_slices)
109109

110110

111+
class PartitionedSplitRewriter(DFPatternCallback):
    """Replace a call to a partitioned split composite function with a plain
    Relay split applied directly to the call's argument."""

    def __init__(self):
        super().__init__(require_type=True, rewrite_once=True)
        # Callee: anything tagged with the split composite name.
        composite_func = wildcard().has_attr(
            {"Composite": ethosu_patterns.SplitParams.composite_name}
        )
        # Full pattern: that callee invoked with one arbitrary argument.
        self.pattern = composite_func(wildcard())

    def callback(
        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
    ) -> tvm.relay.Expr:
        """Rebuild the split outside of the composite function.

        The split parameters are parsed from the body of the matched callee
        (``post.op.body``) and a new ``relay.op.split`` is created on the
        first argument of the matched call. The result is returned as a tuple
        expression.
        """
        ifm = post.args[0]
        params = ethosu_patterns.SplitParams(post.op.body)
        inlined_split = relay.op.split(ifm, params.indices_or_sections, axis=params.axis)
        return inlined_split.astuple()
128+
129+
111130
@ir.transform.module_pass(opt_level=1)
112131
class LegalizeSplit:
113132
"""This is the pass that wraps SplitRewriter"""
@@ -116,6 +135,7 @@ def transform_module(
116135
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
117136
) -> tvm.ir.IRModule:
118137
for global_var, func in mod.functions.items():
138+
func = rewrite(PartitionedSplitRewriter(), func)
119139
func = rewrite(SplitRewriter(), func)
120140
mod.update_func(global_var, func)
121141
return mod

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,44 @@ def concat_pattern():
11071107
return concat
11081108

11091109

1110+
class SplitParams:
    """
    Parses the body of an ethos-u.split composite function and exposes
    the parameters of the split as attributes.
    """

    composite_name = "ethos-u.split"

    def __init__(self, func_body):
        # func_body is the split call itself; keep it around alongside the
        # parameters extracted from it.
        self.split = func_body
        self.input = TensorParams(func_body.args[0])
        split_attrs = func_body.attrs
        self.axis = split_attrs.axis
        self.indices_or_sections = self.convert_indices_or_sections(
            split_attrs.indices_or_sections
        )

    def convert_indices_or_sections(self, indices_or_sections):
        """Unwrap the `indices_or_sections` attribute into plain values.

        Returns a list with the `.value` of every element when the attribute
        is a TVM Array, otherwise the `.value` of the attribute itself.
        """
        if not isinstance(indices_or_sections, tvm.ir.container.Array):
            # split
            return indices_or_sections.value
        # split_v
        return [index.value for index in indices_or_sections]

    def is_valid(self):
        """Checks whether split has compatible attributes with the hardware.

        The only condition checked here is that the input tensor passes
        `check_valid_dtypes` with int8 as the single supported dtype.
        """
        return bool(check_valid_dtypes([self.input], supported_dtypes=[np.int8]))
1140+
1141+
1142+
def split_pattern():
    """Create the pattern for split: the "split" op called on any single input."""
    return is_op("split")(wildcard())
1146+
1147+
11101148
@register_pattern_table("ethos-u")
11111149
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
11121150
return [
@@ -1187,6 +1225,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
11871225
sigmoid_pattern(),
11881226
lambda pat: SigmoidParams(pat).is_valid(),
11891227
),
1228+
(
1229+
SplitParams.composite_name,
1230+
split_pattern(),
1231+
lambda pat: SplitParams(pat).is_valid(),
1232+
),
11901233
]
11911234

11921235

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,5 +929,27 @@ def sigmoid_function(x):
929929
_compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)
930930

931931

932+
# A single codegen test covers both split (integer num_or_size_splits) and
# split_v (tuple of section sizes).
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
    "ifm_shape, num_or_size_splits, axis",
    [
        ((1, 4, 6, 8), (1, 3, 4), 3),
        ((4, 6, 8), 2, 0),
        ((50,), 25, 0),
        ((5, 11), 1, 1),
        ((13,), (13,), 0),
        ((22, 7), (4, -1), 1),
    ],
)
def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis):
    """Build a tf.split graph and hand it to _compare_tvm_with_tflite."""

    @tf.function
    def split_func(x):
        return tf.split(x, num_or_size_splits, axis=axis)

    _compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)
952+
953+
932954
if __name__ == "__main__":
933955
pytest.main([__file__])

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,5 +1344,163 @@ def representative_dataset():
13441344
assert tuple(func_body.args[1].checked_type.shape) == (256,)
13451345

13461346

1347+
@pytest.mark.parametrize(
    "ifm_shape, num_or_size_splits, axis",
    [
        ((1, 4, 6, 8), 3, 2),
        ((4, 6, 8), 2, 0),
        ((5, 15), 3, 1),
        ((3, 7), 1, 1),
        ((100,), 25, 0),
    ],
)
def test_tflite_split_legalize(ifm_shape, num_or_size_splits, axis):
    """Check that PartitionedSplitRewriter exposes a Relay split whose
    attributes match the parameters given to tf.split (integer splits)."""
    dtype = "int8"
    # Name under which the partitioned function is looked up in the module.
    ext_func_name = "tvmgen_default_ethos_u_main_0"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x, num_or_size_splits, axis):
                return tf.split(x, num_or_size_splits, axis=axis)

        concrete_func = Model().tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis
        )

        # Random float32 samples of the input shape for the converter.
        def representative_dataset():
            for _ in range(100):
                sample = np.random.rand(*tuple(ifm_shape))
                yield [sample.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    def verify(ext_func):
        # dig out the split; where it sits in the function body depends on
        # whether the split has a single output
        if num_or_size_splits == 1:
            split = ext_func.body.tuple_value
        else:
            split = ext_func.body.args[0][0].args[0].tuple_value
        assert split.op.name == "split"

        # Split is specified by number of equal chunks
        assert split.attrs.indices_or_sections == num_or_size_splits

        assert split.attrs.axis == axis

    serialized_model = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(serialized_model, 0)

    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )
    mod = ethosu.partition_for_ethosu(mod)

    mod[ext_func_name] = dataflow_pattern.rewrite(
        legalize.PartitionedSplitRewriter(), mod[ext_func_name]
    )

    mod[ext_func_name] = relay.transform.InferType()(mod)[ext_func_name]

    verify(mod[ext_func_name])
1421+
1422+
1423+
@pytest.mark.parametrize(
    "ifm_shape, num_or_size_splits, axis",
    [
        ((1, 4, 6, 8), (1, 3, 4), 3),
        ((10, 18, 4), (1, 4, 3, 2), 0),
        ((22, 7), (4, -1), 1),
        ((25,), (25,), 0),
    ],
)
def test_tflite_split_v_legalize(ifm_shape, num_or_size_splits, axis):
    """Check that PartitionedSplitRewriter exposes a Relay split whose
    attributes match the parameters given to tf.split (tuple of sizes)."""
    dtype = "int8"
    # Name under which the partitioned function is looked up in the module.
    ext_func_name = "tvmgen_default_ethos_u_main_0"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x, num_or_size_splits, axis):
                # TF split gets converted into TFLite's split_v
                return tf.split(x, num_or_size_splits, axis=axis)

        concrete_func = Model().tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis
        )

        # Random float32 samples of the input shape for the converter.
        def representative_dataset():
            for _ in range(100):
                sample = np.random.rand(*tuple(ifm_shape))
                yield [sample.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    def verify(ext_func):
        # dig out the split; where it sits in the function body depends on
        # whether the split has a single output
        single_output_split = len(num_or_size_splits) == 1
        if single_output_split:
            split = ext_func.body.tuple_value
        else:
            split = ext_func.body.args[0][0].args[0].tuple_value
        assert split.op.name == "split"

        # The test parameters give the size of each section, whereas the
        # Relay split stores the indices at which the tensor is cut, so
        # accumulate the sizes (all but the last one) into those indices.
        if single_output_split:
            expected_indices = []
        else:
            expected_indices = [num_or_size_splits[0]]
        for section_size in num_or_size_splits[1:-1]:
            expected_indices.append(expected_indices[-1] + section_size)
        assert list(split.attrs.indices_or_sections) == expected_indices

        assert split.attrs.axis == axis

    serialized_model = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(serialized_model, 0)

    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )
    mod = ethosu.partition_for_ethosu(mod)

    mod[ext_func_name] = dataflow_pattern.rewrite(
        legalize.PartitionedSplitRewriter(), mod[ext_func_name]
    )

    mod[ext_func_name] = relay.transform.InferType()(mod)[ext_func_name]

    verify(mod[ext_func_name])
1503+
1504+
13471505
if __name__ == "__main__":
13481506
pytest.main([__file__])

0 commit comments

Comments
 (0)