@@ -530,7 +530,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
530530 "torch_compile" ,
531531 inputs ,
532532 min_block_size = 1 ,
533- truncate_long_and_double = True ,
533+ truncate_double = True ,
534534 pass_through_build_failures = True ,
535535 )
536536 optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
@@ -593,7 +593,7 @@ def forward(self, x, src, dim, start, end, step):
593593 "torch_compile" ,
594594 inputs ,
595595 min_block_size = 1 ,
596- truncate_long_and_double = True ,
596+ truncate_double = True ,
597597 pass_through_build_failures = True ,
598598 )
599599 optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
@@ -663,7 +663,7 @@ def forward(self, x, src, dim, start, end, step):
663663 "torch_compile" ,
664664 inputs ,
665665 min_block_size = 1 ,
666- truncate_long_and_double = True ,
666+ truncate_double = True ,
667667 pass_through_build_failures = True ,
668668 )
669669 optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
@@ -679,6 +679,195 @@ def forward(self, x, src, dim, start, end, step):
679679 f"Slice_scatter TRT outputs don't match with the original model." ,
680680 )
681681
682+ def test_lowering_select_scatter_dimZero_module (self ):
683+ class selectScatter (torch .nn .Module ):
684+ def __init__ (self , * args , ** kwargs ) -> None :
685+ super ().__init__ (* args , ** kwargs )
686+
687+ def forward (self , x , src , dim , index ):
688+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
689+ return y
690+
691+ # Operations expected to be removed in the traced graph after decompositions
692+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
693+ unexpected_ops = {
694+ torch .ops .aten .select_scatter .default ,
695+ torch .ops .aten .slice_scatter .default ,
696+ }
697+
698+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 0 , 0 ]
699+
700+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
701+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
702+ fx_graph ,
703+ inputs ,
704+ expected_ops = expected_ops ,
705+ unexpected_ops = unexpected_ops ,
706+ min_block_size = 1 ,
707+ )
708+
709+ self .assertEqual (
710+ len (unexpected_ops_seen ),
711+ 0 ,
712+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
713+ )
714+
715+ self .assertEqual (
716+ len (expected_ops_unseen ),
717+ 0 ,
718+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
719+ )
720+
721+ torch ._dynamo .reset ()
722+
723+ # Validate that the results between Torch and Torch-TRT are similar
724+ optimized_model = torch_tensorrt .compile (
725+ fx_graph ,
726+ "torch_compile" ,
727+ inputs ,
728+ min_block_size = 1 ,
729+ truncate_and_double = True ,
730+ pass_through_build_failures = True ,
731+ )
732+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
733+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
734+
735+ max_diff = float (
736+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
737+ )
738+ self .assertAlmostEqual (
739+ max_diff ,
740+ 0 ,
741+ DECIMALS_OF_AGREEMENT ,
742+ f"Select_scatter TRT outputs don't match with the original model." ,
743+ )
744+
745+ def test_lowering_select_scatter_dimOne_module (self ):
746+ class selectScatter (torch .nn .Module ):
747+ def __init__ (self , * args , ** kwargs ) -> None :
748+ super ().__init__ (* args , ** kwargs )
749+
750+ def forward (self , x , src , dim , index ):
751+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
752+ return y
753+
754+ # Operations expected to be removed in the traced graph after decompositions
755+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
756+ unexpected_ops = {
757+ torch .ops .aten .select_scatter .default ,
758+ torch .ops .aten .slice_scatter .default ,
759+ }
760+
761+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 1 , 0 ]
762+
763+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
764+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
765+ fx_graph ,
766+ inputs ,
767+ expected_ops = expected_ops ,
768+ unexpected_ops = unexpected_ops ,
769+ min_block_size = 1 ,
770+ )
771+
772+ self .assertEqual (
773+ len (unexpected_ops_seen ),
774+ 0 ,
775+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
776+ )
777+
778+ self .assertEqual (
779+ len (expected_ops_unseen ),
780+ 0 ,
781+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
782+ )
783+
784+ torch ._dynamo .reset ()
785+
786+ # Validate that the results between Torch and Torch-TRT are similar
787+ optimized_model = torch_tensorrt .compile (
788+ fx_graph ,
789+ "torch_compile" ,
790+ inputs ,
791+ min_block_size = 1 ,
792+ truncate_double = True ,
793+ pass_through_build_failures = True ,
794+ )
795+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
796+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
797+
798+ max_diff = float (
799+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
800+ )
801+ self .assertAlmostEqual (
802+ max_diff ,
803+ 0 ,
804+ DECIMALS_OF_AGREEMENT ,
805+ f"Select_scatter TRT outputs don't match with the original model." ,
806+ )
807+
808+ def test_lowering_select_scatter_multidimension_module (self ):
809+ class selectScatter (torch .nn .Module ):
810+ def __init__ (self , * args , ** kwargs ) -> None :
811+ super ().__init__ (* args , ** kwargs )
812+
813+ def forward (self , x , src , dim , index ):
814+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
815+ return y
816+
817+ # Operations expected to be removed in the traced graph after decompositions
818+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
819+ unexpected_ops = {
820+ torch .ops .aten .select_scatter .default ,
821+ torch .ops .aten .slice_scatter .default ,
822+ }
823+
824+ inputs = [torch .zeros (2 , 3 , 4 ).cuda (), torch .ones (2 , 4 ).cuda (), 1 , 0 ]
825+
826+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
827+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
828+ fx_graph ,
829+ inputs ,
830+ expected_ops = expected_ops ,
831+ unexpected_ops = unexpected_ops ,
832+ min_block_size = 1 ,
833+ )
834+
835+ self .assertEqual (
836+ len (unexpected_ops_seen ),
837+ 0 ,
838+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
839+ )
840+
841+ self .assertEqual (
842+ len (expected_ops_unseen ),
843+ 0 ,
844+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
845+ )
846+
847+ torch ._dynamo .reset ()
848+
849+ # Validate that the results between Torch and Torch-TRT are similar
850+ optimized_model = torch_tensorrt .compile (
851+ fx_graph ,
852+ "torch_compile" ,
853+ inputs ,
854+ min_block_size = 1 ,
855+ truncate_double = True ,
856+ pass_through_build_failures = True ,
857+ )
858+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
859+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
860+
861+ max_diff = float (
862+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
863+ )
864+ self .assertAlmostEqual (
865+ max_diff ,
866+ 0 ,
867+ DECIMALS_OF_AGREEMENT ,
868+ f"Select_scatter TRT outputs don't match with the original model." ,
869+ )
870+
682871
683872if __name__ == "__main__" :
684873 run_tests ()
0 commit comments