@@ -1344,5 +1344,163 @@ def representative_dataset():
13441344 assert tuple (func_body .args [1 ].checked_type .shape ) == (256 ,)
13451345
13461346
1347+ @pytest .mark .parametrize (
1348+ "ifm_shape, num_or_size_splits, axis" ,
1349+ [
1350+ ((1 , 4 , 6 , 8 ), 3 , 2 ),
1351+ ((4 , 6 , 8 ), 2 , 0 ),
1352+ ((5 , 15 ), 3 , 1 ),
1353+ ((3 , 7 ), 1 , 1 ),
1354+ ((100 ,), 25 , 0 ),
1355+ ],
1356+ )
1357+ def test_tflite_split_legalize (ifm_shape , num_or_size_splits , axis ):
1358+ dtype = "int8"
1359+
1360+ def create_tflite_graph ():
1361+ class Model (tf .Module ):
1362+ @tf .function
1363+ def tf_function (self , x , num_or_size_splits , axis ):
1364+ op = tf .split (x , num_or_size_splits , axis = axis )
1365+ return op
1366+
1367+ model = Model ()
1368+ concrete_func = model .tf_function .get_concrete_function (
1369+ tf .TensorSpec (ifm_shape , tf .float32 ), num_or_size_splits , axis
1370+ )
1371+
1372+ def representative_dataset ():
1373+ for _ in range (100 ):
1374+ data = np .random .rand (* tuple (ifm_shape ))
1375+ yield [data .astype (np .float32 )]
1376+
1377+ converter = tf .lite .TFLiteConverter .from_concrete_functions ([concrete_func ])
1378+ converter .optimizations = [tf .lite .Optimize .DEFAULT ]
1379+ converter .representative_dataset = representative_dataset
1380+ converter .target_spec .supported_ops = [tf .lite .OpsSet .TFLITE_BUILTINS_INT8 ]
1381+ converter .inference_input_type = tf .int8
1382+ converter .inference_output_type = tf .int8
1383+ tflite_model = converter .convert ()
1384+
1385+ return tflite_model
1386+
1387+ def verify (ext_func ):
1388+ # dig out the split
1389+ single_output_split = num_or_size_splits == 1
1390+ split = (
1391+ ext_func .body .tuple_value
1392+ if single_output_split
1393+ else ext_func .body .args [0 ][0 ].args [0 ].tuple_value
1394+ )
1395+ assert split .op .name == "split"
1396+
1397+ # Split is specified by number of equal chunks
1398+ assert split .attrs .indices_or_sections == num_or_size_splits
1399+
1400+ assert split .attrs .axis == axis
1401+
1402+ tflite_graph = create_tflite_graph ()
1403+ tflite_model = tflite .Model .Model .GetRootAsModel (tflite_graph , 0 )
1404+
1405+ mod , _ = relay .frontend .from_tflite (
1406+ tflite_model ,
1407+ shape_dict = {"input" : ifm_shape },
1408+ dtype_dict = {"input" : dtype },
1409+ )
1410+ mod = ethosu .partition_for_ethosu (mod )
1411+
1412+ mod ["tvmgen_default_ethos_u_main_0" ] = dataflow_pattern .rewrite (
1413+ legalize .PartitionedSplitRewriter (), mod ["tvmgen_default_ethos_u_main_0" ]
1414+ )
1415+
1416+ mod ["tvmgen_default_ethos_u_main_0" ] = relay .transform .InferType ()(mod )[
1417+ "tvmgen_default_ethos_u_main_0"
1418+ ]
1419+
1420+ verify (mod ["tvmgen_default_ethos_u_main_0" ])
1421+
1422+
1423+ @pytest .mark .parametrize (
1424+ "ifm_shape, num_or_size_splits, axis" ,
1425+ [
1426+ ((1 , 4 , 6 , 8 ), (1 , 3 , 4 ), 3 ),
1427+ ((10 , 18 , 4 ), (1 , 4 , 3 , 2 ), 0 ),
1428+ ((22 , 7 ), (4 , - 1 ), 1 ),
1429+ ((25 ,), (25 ,), 0 ),
1430+ ],
1431+ )
1432+ def test_tflite_split_v_legalize (ifm_shape , num_or_size_splits , axis ):
1433+ dtype = "int8"
1434+
1435+ def create_tflite_graph ():
1436+ class Model (tf .Module ):
1437+ @tf .function
1438+ def tf_function (self , x , num_or_size_splits , axis ):
1439+ # TF split gets converted into TFLite's split_v
1440+ op = tf .split (x , num_or_size_splits , axis = axis )
1441+ return op
1442+
1443+ model = Model ()
1444+ concrete_func = model .tf_function .get_concrete_function (
1445+ tf .TensorSpec (ifm_shape , tf .float32 ), num_or_size_splits , axis
1446+ )
1447+
1448+ def representative_dataset ():
1449+ for _ in range (100 ):
1450+ data = np .random .rand (* tuple (ifm_shape ))
1451+ yield [data .astype (np .float32 )]
1452+
1453+ converter = tf .lite .TFLiteConverter .from_concrete_functions ([concrete_func ])
1454+ converter .optimizations = [tf .lite .Optimize .DEFAULT ]
1455+ converter .representative_dataset = representative_dataset
1456+ converter .target_spec .supported_ops = [tf .lite .OpsSet .TFLITE_BUILTINS_INT8 ]
1457+ converter .inference_input_type = tf .int8
1458+ converter .inference_output_type = tf .int8
1459+ tflite_model = converter .convert ()
1460+
1461+ return tflite_model
1462+
1463+ def verify (ext_func ):
1464+ # dig out the split
1465+ single_output_split = len (num_or_size_splits ) == 1
1466+ split = (
1467+ ext_func .body .tuple_value
1468+ if single_output_split
1469+ else ext_func .body .args [0 ][0 ].args [0 ].tuple_value
1470+ )
1471+ assert split .op .name == "split"
1472+
1473+ # Split is specified by the size of sections, so converting num_or_size_splits
1474+ # into the indices where the tensor is split at since this is how split is represented
1475+ # in Relay
1476+ split_sections = [] if single_output_split else [num_or_size_splits [0 ]]
1477+ for split_size in num_or_size_splits [1 :- 1 ]:
1478+ sec = split_sections [- 1 ] + split_size
1479+ split_sections .append (sec )
1480+ assert list (split .attrs .indices_or_sections ) == split_sections
1481+
1482+ assert split .attrs .axis == axis
1483+
1484+ tflite_graph = create_tflite_graph ()
1485+ tflite_model = tflite .Model .Model .GetRootAsModel (tflite_graph , 0 )
1486+
1487+ mod , _ = relay .frontend .from_tflite (
1488+ tflite_model ,
1489+ shape_dict = {"input" : ifm_shape },
1490+ dtype_dict = {"input" : dtype },
1491+ )
1492+ mod = ethosu .partition_for_ethosu (mod )
1493+
1494+ mod ["tvmgen_default_ethos_u_main_0" ] = dataflow_pattern .rewrite (
1495+ legalize .PartitionedSplitRewriter (), mod ["tvmgen_default_ethos_u_main_0" ]
1496+ )
1497+
1498+ mod ["tvmgen_default_ethos_u_main_0" ] = relay .transform .InferType ()(mod )[
1499+ "tvmgen_default_ethos_u_main_0"
1500+ ]
1501+
1502+ verify (mod ["tvmgen_default_ethos_u_main_0" ])
1503+
1504+
13471505if __name__ == "__main__" :
13481506 pytest .main ([__file__ ])
0 commit comments