@@ -66,40 +66,4 @@ TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
66
66
auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
67
67
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-6 ));
68
68
}
69
-
70
- TEST (Partitioning, ComputeResNet50HalfFallbackGraphCorrectly) {
71
- torch::jit::script::Module mod;
72
- try {
73
- mod = torch::jit::load (" tests/modules/resnet50_traced.jit.pt" );
74
- } catch (const c10::Error& e) {
75
- std::cerr << " error loading the model\n " ;
76
- return ;
77
- }
78
-
79
- mod.to (torch::kHalf );
80
-
81
- const std::vector<std::vector<int64_t >> input_shapes = {{1 , 3 , 224 , 224 }};
82
- std::vector<torch::jit::IValue> jit_inputs_ivalues;
83
- std::vector<torch::jit::IValue> trt_inputs_ivalues;
84
- for (auto in_shape : input_shapes) {
85
- auto in = at::randint (5 , in_shape, {at::kCUDA }).to (torch::kHalf );
86
- jit_inputs_ivalues.push_back (in.clone ());
87
- trt_inputs_ivalues.push_back (in.clone ());
88
- }
89
-
90
- auto in_shape = torch_tensorrt::core::ir::Input ({1 , 3 , 224 , 224 });
91
- in_shape.dtype = nvinfer1::DataType::kHALF ;
92
-
93
- std::vector<torch_tensorrt::core::ir::Input> input_ranges ({in_shape});
94
- auto g = mod.get_method (" forward" ).graph ();
95
- torch_tensorrt::core::CompileSpec cfg (input_ranges);
96
- cfg.partition_info .enabled = true ;
97
- cfg.partition_info .forced_fallback_operators .push_back (" aten::add" );
98
-
99
- auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
100
- auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
101
- auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
102
- // Lower threshold because FP16
103
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-1 ));
104
- }
105
69
#endif
0 commit comments