@@ -420,7 +420,7 @@ def forward(self, x):
420420 f"MaxPool3d TRT outputs don't match with the original model." ,
421421 )
422422
423- def test_lowering_select_scatter_module (self ):
423+ def test_lowering_select_scatter_dimZero_module (self ):
424424 class selectScatter (torch .nn .Module ):
425425 def __init__ (self , * args , ** kwargs ) -> None :
426426 super ().__init__ (* args , ** kwargs )
@@ -484,5 +484,67 @@ def forward(self, x, src, dim, index):
484484 )
485485
486486
487+ def test_lowering_select_scatter_dimOne_module (self ):
488+ class selectScatter (torch .nn .Module ):
489+ def __init__ (self , * args , ** kwargs ) -> None :
490+ super ().__init__ (* args , ** kwargs )
491+
492+ def forward (self , x , src , dim , index ):
493+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
494+ return y
495+
496+ # Operations expected to be removed in the traced graph after decompositions
497+ expected_ops = {
498+ torch .ops .aten .slice .Tensor ,
499+ torch .ops .aten .squeeze .dim ,
500+ torch .ops .aten .cat .default ,
501+ }
502+ unexpected_ops = {torch .ops .aten .select_scatter .default }
503+
504+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 1 , 0 ]
505+
506+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
507+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
508+ fx_graph ,
509+ inputs ,
510+ expected_ops = expected_ops ,
511+ unexpected_ops = unexpected_ops ,
512+ min_block_size = 1 ,
513+ )
514+
515+ self .assertEquals (
516+ len (unexpected_ops_seen ),
517+ 0 ,
518+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
519+ )
520+
521+ self .assertEquals (
522+ len (expected_ops_unseen ),
523+ 0 ,
524+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
525+ )
526+
527+ torch ._dynamo .reset ()
528+
529+ # Validate that the results between Torch and Torch-TRT are similar
530+ optimized_model = torch_tensorrt .compile (
531+ fx_graph ,
532+ "torch_compile" ,
533+ inputs ,
534+ min_block_size = 1 ,
535+ pass_through_build_failures = True ,
536+ )
537+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
538+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
539+
540+ max_diff = float (
541+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
542+ )
543+ self .assertAlmostEqual (
544+ max_diff ,
545+ 0 ,
546+ DECIMALS_OF_AGREEMENT ,
547+ f"Select_scatter TRT outputs don't match with the original model." ,
548+ )
487549if __name__ == "__main__" :
488550 run_tests ()
0 commit comments