@@ -35,7 +35,6 @@ struct Resize : public torch::data::transforms::TensorTransform<torch::Tensor> {
 torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::Module& mod) {
     auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
                                    .use_subset(320)
-                                   .map(Resize({300, 300}))
                                    .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                                                              {0.2023, 0.1994, 0.2010}))
                                    .map(torch::data::transforms::Stack<>());
@@ -48,7 +47,7 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M
     auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
 
 
-    std::vector<std::vector<int64_t>> input_shape = {{32, 3, 300, 300}};
+    std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
     /// Configure settings for compilation
     auto extra_info = trtorch::ExtraInfo({input_shape});
     /// Set operating precision to INT8
@@ -99,7 +98,6 @@ int main(int argc, const char* argv[]) {
 
     /// Dataloader moved into calibrator so need another for inference
     auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
-                            .map(Resize({300, 300}))
                             .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                                                       {0.2023, 0.1994, 0.2010}))
                             .map(torch::data::transforms::Stack<>());
@@ -131,7 +129,7 @@ int main(int argc, const char* argv[]) {
         if (images.sizes()[0] < 32) {
             /// To handle smaller batches util Optimization profiles work with Int8
             auto diff = 32 - images.sizes()[0];
-            auto img_padding = torch::zeros({diff, 3, 300, 300}, {torch::kCUDA});
+            auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
             auto target_padding = torch::zeros({diff}, {torch::kCUDA});
             images = torch::cat({images, img_padding}, 0);
             targets = torch::cat({targets, target_padding}, 0);
@@ -152,7 +150,7 @@ int main(int argc, const char* argv[]) {
     std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
     /// Time execution in JIT-FP32 and TRT-INT8
-    std::vector<std::vector<int64_t>> dims = {{32, 3, 300, 300}};
+    std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};
 
     auto jit_runtimes = benchmark_module(mod, dims[0]);
     print_avg_std_dev("JIT model FP32", jit_runtimes, dims[0][0]);