@@ -1587,6 +1587,75 @@ def forward(self, x):
15871587 f"Log_softmax TRT outputs don't match with the original model." ,
15881588 )
15891589
1590+ @parameterized .expand (
1591+ [
1592+ ((1 , 3 , 5 ), True ),
1593+ ((1 , 3 , 5 ), False ),
1594+ ((2 , 4 , 6 , 8 ), True ),
1595+ ((2 , 4 , 6 , 8 ), False ),
1596+ ((3 , 6 , 9 , 12 , 15 ), True ),
1597+ ((3 , 6 , 9 , 12 , 15 ), False ),
1598+ ]
1599+ )
1600+ def test_lowering_instance_norm (self , shape , use_input_stats ):
1601+ class TestModule (torch .nn .Module ):
1602+ def forward (self , input , weight , bias , running_mean = None , running_var = None ):
1603+ return torch .ops .aten .instance_norm .default (
1604+ input ,
1605+ weight ,
1606+ bias ,
1607+ running_mean ,
1608+ running_var ,
1609+ use_input_stats ,
1610+ 0.1 ,
1611+ 1e-05 ,
1612+ True ,
1613+ )
1614+
1615+ # Operations expected to be removed in the traced graph after decompositions
1616+ unexpected_ops = {torch .ops .aten .instance_norm .default }
1617+
1618+ inputs = [
1619+ torch .randn (shape , device = "cuda" ),
1620+ torch .randn (shape [1 ], device = "cuda" ),
1621+ torch .randn (shape [1 ], device = "cuda" ),
1622+ ]
1623+ if not use_input_stats :
1624+ inputs += [
1625+ torch .randn (shape [1 ], device = "cuda" ),
1626+ torch .rand (shape [1 ], device = "cuda" ),
1627+ ]
1628+
1629+ fx_graph = torch .fx .symbolic_trace (TestModule ())
1630+ unexpected_ops_seen , _ = lower_graph_testing (
1631+ fx_graph , inputs , unexpected_ops = unexpected_ops , min_block_size = 1
1632+ )
1633+
1634+ self .assertEqual (
1635+ len (unexpected_ops_seen ),
1636+ 0 ,
1637+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
1638+ )
1639+
1640+ torch ._dynamo .reset ()
1641+
1642+ # Validate that the results between Torch and Torch-TRT are similar
1643+ optimized_model = torch_tensorrt .compile (
1644+ fx_graph , "dynamo" , inputs , min_block_size = 1
1645+ )
1646+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
1647+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
1648+
1649+ max_diff = float (
1650+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
1651+ )
1652+ self .assertAlmostEqual (
1653+ max_diff ,
1654+ 0 ,
1655+ DECIMALS_OF_AGREEMENT ,
1656+ "Instance_norm TRT outputs don't match with the original model." ,
1657+ )
1658+
15901659
15911660if __name__ == "__main__" :
15921661 run_tests ()
0 commit comments