@@ -962,6 +962,131 @@ def forward(self, input):
             f"The optimized model results shape and torch model results shape should be equal in empty_stride",
         )
 
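+    # Each case supplies (test name, scatter dim, index tensor, src tensor,
+    # ops expected to appear in the lowered graph after decomposition).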
+    @parameterized.expand(
+        [
+            (
+                "scatter_add_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor},
+            ),
+            (
+                "scatter_add_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
+            ),
+            (
+                "scatter_add_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+            ),
+            (
+                "scatter_add_one_dim_indexTwo_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+            ),
+            (
+                "scatter_add_one_dim_indexThree_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1], [3, 2, 1, 2]]).cuda(),
+                torch.tensor(
+                    [[1, 2, 3, 1], [5, 6, 5, 5], [2, 4, 3, 2]], dtype=torch.int32
+                ).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+            ),
+        ]
+    )
+    def test_scatter_add(self, _, dim, index, src, expected_ops_param):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
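+                # scatter_add accumulates src into input at the positions given
+                # by index along `dim`; repeated indices sum their contributions
+                # (e.g. out[index[i][j]][j] += src[i][j] when dim == 0).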
+                return torch.ops.aten.scatter_add.default(input, dim, index, src)
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = expected_ops_param
+        unexpected_ops = {torch.ops.aten.scatter_add.default}
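+        # Lowering should decompose scatter_add into combinations of
+        # scatter.src, add.Tensor, and full_like (per the expectations above);
+        # the original op must not survive in the traced graph.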
+
+        input = torch.zeros(3, 5, dtype=torch.int32).cuda()
+        inputs = [input]
+
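+        # Trace the module with FX, run the lowering passes, and record which
+        # ops appear in (or are absent from) the resulting graph.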
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=2,
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
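+        # Reset Dynamo so the compilation below starts from a clean cache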
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
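+        # Compare outputs via the maximum absolute elementwise difference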
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Scatter_add TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()