| 
 | 1 | +import unittest  | 
 | 2 | +import torch_tensorrt as torchtrt  | 
 | 3 | +import torch  | 
 | 4 | +import torchvision.models as models  | 
 | 5 | +import timm  | 
 | 6 | + | 
 | 7 | + | 
 | 8 | +COS_SIM_THRESHOLD = 0.99  | 
 | 9 | + | 
 | 10 | + | 
 | 11 | +def cosine_similarity_custom(trt_out, torch_out):  | 
 | 12 | +    torch.nn.functional.cosine_similarity(  | 
 | 13 | +        trt_out.flatten(), torch_out.flatten(), dim=0, eps=1e-6  | 
 | 14 | +    )  | 
 | 15 | + | 
 | 16 | + | 
 | 17 | +class TestCompileE2E(unittest.TestCase):  | 
 | 18 | +    def test_resnet18_fx(self):  | 
 | 19 | +        self.model = models.resnet18(pretrained=True).eval().to("cuda")  | 
 | 20 | +        self.input = torch.randn((1, 3, 224, 224)).to("cuda")  | 
 | 21 | + | 
 | 22 | +        compile_spec = {  | 
 | 23 | +            "inputs": [self.input],  | 
 | 24 | +            "enabled_precisions": {torch.float},  | 
 | 25 | +        }  | 
 | 26 | + | 
 | 27 | +        trt_mod = torchtrt.compile(self.model, ir="fx", **compile_spec)  | 
 | 28 | +        cos_sim = cosine_similarity_custom(  | 
 | 29 | +            self.model(self.input),  | 
 | 30 | +            trt_mod(self.input),  | 
 | 31 | +            dim=0,  | 
 | 32 | +            eps=1e-6,  | 
 | 33 | +        )  | 
 | 34 | +        self.assertTrue(  | 
 | 35 | +            cos_sim > COS_SIM_THRESHOLD,  | 
 | 36 | +            msg=f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COS_SIM_THRESHOLD}",  | 
 | 37 | +        )  | 
 | 38 | + | 
 | 39 | +    def test_mobilenet_v2_fx(self):  | 
 | 40 | +        self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda")  | 
 | 41 | +        self.input = torch.randn((1, 3, 224, 224)).to("cuda")  | 
 | 42 | + | 
 | 43 | +        compile_spec = {  | 
 | 44 | +            "inputs": [self.input],  | 
 | 45 | +            "enabled_precisions": {torch.float},  | 
 | 46 | +        }  | 
 | 47 | + | 
 | 48 | +        trt_mod = torchtrt.compile(self.model, ir="fx", **compile_spec)  | 
 | 49 | +        cos_sim = cosine_similarity_custom(  | 
 | 50 | +            self.model(self.input),  | 
 | 51 | +            trt_mod(self.input),  | 
 | 52 | +            dim=0,  | 
 | 53 | +            eps=1e-6,  | 
 | 54 | +        )  | 
 | 55 | +        self.assertTrue(  | 
 | 56 | +            cos_sim > COS_SIM_THRESHOLD,  | 
 | 57 | +            msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COS_SIM_THRESHOLD}",  | 
 | 58 | +        )  | 
 | 59 | + | 
 | 60 | +    def test_efficientnet_b0_fx(self):  | 
 | 61 | +        self.model = (  | 
 | 62 | +            timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")  | 
 | 63 | +        )  | 
 | 64 | +        self.input = torch.randn((1, 3, 224, 224)).to("cuda")  | 
 | 65 | + | 
 | 66 | +        compile_spec = {  | 
 | 67 | +            "inputs": [self.input],  | 
 | 68 | +            "enabled_precisions": {torch.float},  | 
 | 69 | +        }  | 
 | 70 | + | 
 | 71 | +        trt_mod = torchtrt.compile(self.model, ir="fx", **compile_spec)  | 
 | 72 | +        cos_sim = cosine_similarity_custom(  | 
 | 73 | +            self.model(self.input),  | 
 | 74 | +            trt_mod(self.input),  | 
 | 75 | +            dim=0,  | 
 | 76 | +            eps=1e-6,  | 
 | 77 | +        )  | 
 | 78 | +        self.assertTrue(  | 
 | 79 | +            cos_sim > COS_SIM_THRESHOLD,  | 
 | 80 | +            msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COS_SIM_THRESHOLD}",  | 
 | 81 | +        )  | 
 | 82 | + | 
 | 83 | +    def test_resnet18_half_fx(self):  | 
 | 84 | +        self.model = models.resnet18(pretrained=True).eval().to("cuda").half()  | 
 | 85 | +        self.input = torch.randn((1, 3, 224, 224)).to("cuda").half()  | 
 | 86 | + | 
 | 87 | +        compile_spec = {  | 
 | 88 | +            "inputs": [self.input],  | 
 | 89 | +            "enabled_precisions": {torch.half},  | 
 | 90 | +        }  | 
 | 91 | + | 
 | 92 | +        trt_mod = torchtrt.compile(self.model, ir="fx", **compile_spec)  | 
 | 93 | +        cos_sim = cosine_similarity_custom(  | 
 | 94 | +            self.model(self.input),  | 
 | 95 | +            trt_mod(self.input),  | 
 | 96 | +            dim=0,  | 
 | 97 | +            eps=1e-6,  | 
 | 98 | +        )  | 
 | 99 | +        self.assertTrue(  | 
 | 100 | +            cos_sim > COS_SIM_THRESHOLD,  | 
 | 101 | +            msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COS_SIM_THRESHOLD}",  | 
 | 102 | +        )  | 
 | 103 | + | 
 | 104 | + | 
 | 105 | +if __name__ == "__main__":  | 
 | 106 | +    unittest.main()  | 
0 commit comments