@@ -420,6 +420,71 @@ def forward(self, x):
420420 f"MaxPool3d TRT outputs don't match with the original model." ,
421421 )
422422
423+ def test_lowering_empty_like_module (self ):
424+ class emptyLike (torch .nn .Module ):
425+ def __init__ (self , * args , ** kwargs ) -> None :
426+ super ().__init__ (* args , ** kwargs )
427+
428+ def forward (self , x ):
429+ c = torch .ops .aten .add (x , x )
430+ y = torch .ops .aten .empty_like .default (c )
431+ d = y + c
432+ return d
433+
434+ # Operations expected to be removed in the traced graph after decompositions
435+ expected_ops = {torch .ops .aten .add .Tensor }
436+ unexpected_ops = {
437+ torch .ops .aten .empty_like .default ,
438+ torch .ops .aten .empty_permuted .default ,
439+ }
440+
441+ inputs = [torch .zeros (3 , 2 ).cuda ()]
442+
443+ fx_graph = torch .fx .symbolic_trace (emptyLike ())
444+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
445+ fx_graph ,
446+ inputs ,
447+ expected_ops = expected_ops ,
448+ unexpected_ops = unexpected_ops ,
449+ min_block_size = 1 ,
450+ )
451+
452+ self .assertEquals (
453+ len (unexpected_ops_seen ),
454+ 0 ,
455+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
456+ )
457+
458+ self .assertEquals (
459+ len (expected_ops_unseen ),
460+ 0 ,
461+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
462+ )
463+
464+ torch ._dynamo .reset ()
465+
466+ # Validate that the results between Torch and Torch-TRT are similar
467+ optimized_model = torch_tensorrt .compile (
468+ fx_graph ,
469+ "torch_compile" ,
470+ inputs ,
471+ min_block_size = 1 ,
472+ truncate_long_and_double = True ,
473+ pass_through_build_failures = True ,
474+ )
475+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
476+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
477+
478+ max_diff = float (
479+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
480+ )
481+ self .assertAlmostEqual (
482+ max_diff ,
483+ 0 ,
484+ DECIMALS_OF_AGREEMENT ,
485+ f"Select_scatter TRT outputs don't match with the original model." ,
486+ )
487+
423488
424489if __name__ == "__main__" :
425490 run_tests ()
0 commit comments