@@ -1608,6 +1608,7 @@ def verify(ext_func):
16081608 [
16091609 [(1 , 2 , 2 , 1 ), (4 , 4 )],
16101610 [(1 , 4 , 7 , 3 ), (8 , 14 )],
1611+ [(1 , 3 , 5 , 3 ), (3 , 5 )],
16111612 ],
16121613)
16131614def test_tflite_resize2d_nearest_neighbor (ifm_shape , size ):
@@ -1647,21 +1648,25 @@ def representative_dataset():
16471648 return mod
16481649
16491650 def verify (ext_func ):
1650- identity = ext_func .body
1651- in_var = identity .args [0 ]
1651+ op = ext_func .body
1652+ in_var = op .args [0 ]
16521653
16531654 # check IFM
16541655 assert tuple (in_var .checked_type .shape ) == ifm_shape
16551656 assert in_var .checked_type .dtype == dtype
16561657
16571658 # check OFM
1658- attrs = dict (identity .attrs )
1659+ attrs = dict (op .attrs )
16591660 out_shape = (ifm_shape [0 ], size [0 ], size [1 ], ifm_shape [3 ])
1660- assert tuple (identity .checked_type .shape ) == out_shape
1661- assert identity .checked_type .dtype == dtype
1661+ assert tuple (op .checked_type .shape ) == out_shape
1662+ assert op .checked_type .dtype == dtype
16621663
16631664 # Check Op attributes
1664- assert attrs ["upscale" ] == "NEAREST"
1665+ if size [0 ] == ifm_shape [1 ] and size [1 ] == ifm_shape [2 ]:
1666+ assert op .op .name == "contrib.ethosu.identity"
1667+ else :
1668+ assert attrs ["pooling_type" ] == "AVG"
1669+ assert attrs ["upscale" ] == "NEAREST"
16651670
16661671 rewriter = legalize .Resize2dRewriter ()
16671672 pattern_table = [
@@ -1687,6 +1692,7 @@ def verify(ext_func):
16871692 [(1 , 4 , 7 , 3 ), (8 , 14 ), False ],
16881693 [(1 , 2 , 2 , 1 ), (3 , 3 ), True ],
16891694 [(1 , 4 , 7 , 3 ), (7 , 13 ), True ],
1695+ [(1 , 3 , 5 , 3 ), (3 , 5 ), False ],
16901696 ],
16911697)
16921698def test_tflite_resize2d_bilinear (ifm_shape , size , align_corners ):
@@ -1725,28 +1731,31 @@ def representative_dataset():
17251731 return mod
17261732
17271733 def verify (ext_func ):
1728- avg_pool = ext_func .body
1729- in_var = avg_pool .args [0 ]
1734+ op = ext_func .body
1735+ in_var = op .args [0 ]
17301736
17311737 # check IFM
17321738 assert tuple (in_var .checked_type .shape ) == ifm_shape
17331739 assert in_var .checked_type .dtype == dtype
17341740
17351741 # check OFM
1736- attrs = dict (avg_pool .attrs )
1742+ attrs = dict (op .attrs )
17371743 out_shape = (ifm_shape [0 ], size [0 ], size [1 ], ifm_shape [3 ])
1738- assert tuple (avg_pool .checked_type .shape ) == out_shape
1739- assert avg_pool .checked_type .dtype == dtype
1744+ assert tuple (op .checked_type .shape ) == out_shape
1745+ assert op .checked_type .dtype == dtype
17401746
17411747 # Check Op attributes
1742- assert attrs ["pooling_type" ] == "AVG"
1743- assert attrs ["upscale" ] == "NEAREST"
1744-
1745- # Check padding
1746- if align_corners :
1747- assert list (attrs ["padding" ]) == [0 , 0 , 0 , 0 ]
1748+ if size [0 ] == ifm_shape [1 ] and size [1 ] == ifm_shape [2 ]:
1749+ assert op .op .name == "contrib.ethosu.identity"
17481750 else :
1749- assert list (attrs ["padding" ]) == [0 , 0 , 1 , 1 ]
1751+ assert attrs ["pooling_type" ] == "AVG"
1752+ assert attrs ["upscale" ] == "NEAREST"
1753+
1754+ # Check padding
1755+ if align_corners :
1756+ assert list (attrs ["padding" ]) == [0 , 0 , 0 , 0 ]
1757+ else :
1758+ assert list (attrs ["padding" ]) == [0 , 0 , 1 , 1 ]
17501759
17511760 rewriter = legalize .Resize2dRewriter ()
17521761 pattern_table = [
0 commit comments