@@ -1,6 +1,5 @@
 #include "torch/script.h"
 #include "torch/torch.h"
-#include "trtorch/ptq.h"
 #include "trtorch/trtorch.h"
 
 #include "NvInfer.h"
@@ -28,23 +27,24 @@ struct Resize : public torch::data::transforms::TensorTransform<torch::Tensor> {
   std::vector<int64_t> new_size_;
 };
 
-torch::jit::Module compile_int8_qat_model(torch::jit::Module& mod) {
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+torch::jit::Module compile_int8_qat_model(const std::string& data_dir, torch::jit::Module& mod) {
+
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   /// Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   /// Set operating precision to INT8
+  // compile_spec.enabled_precisions.insert(torch::kF16);
   compile_spec.enabled_precisions.insert(torch::kI8);
   /// Set max batch size for the engine
   compile_spec.max_batch_size = 32;
   /// Set a larger workspace
   compile_spec.workspace_size = 1 << 28;
 
-  mod.eval();
-
 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
-  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif
@@ -71,62 +71,53 @@ int main(int argc, const char* argv[]) {
     return -1;
   }
 
-  /// Convert the model using TensorRT
-  auto trt_mod = compile_int8_qat_model(mod);
-  std::cout << "Model conversion to TensorRT completed." << std::endl;
-  /// Dataloader moved into calibrator so need another for inference
+  mod.eval();
+
+  /// Create the calibration dataset
   const std::string data_dir = std::string(argv[2]);
+
+  /// Dataloader moved into calibrator so need another for inference
   auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
   auto eval_dataloader = torch::data::make_data_loader(
       std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));
 
   /// Check the FP32 accuracy in JIT
-  float correct = 0.0, total = 0.0;
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
     auto outputs = mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
 
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
   }
-  std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%"
-            << " correct: " << correct << " total: " << total << std::endl;
+  torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
+
+  /// Compile Graph
+  auto trt_mod = compile_int8_qat_model(data_dir, mod);
 
   /// Check the INT8 accuracy in TRT
-  correct = 0.0;
-  total = 0.0;
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
-    if (images.sizes()[0] < 32) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      auto diff = 32 - images.sizes()[0];
-      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
-      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
-      images = torch::cat({images, img_padding}, 0);
-      targets = torch::cat({targets, target_padding}, 0);
-    }
-
     auto outputs = trt_mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
     predictions = predictions.reshape(predictions.sizes()[0]);
 
-    if (predictions.sizes()[0] != targets.sizes()[0]) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      predictions = predictions.slice(0, 0, targets.sizes()[0]);
-    }
-
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
   }
-  std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%"
-            << " correct: " << correct << " total: " << total << std::endl;
+  torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
+
+  std::cout << "Accuracy of JIT model on test set: " << jit_accuracy.item().toFloat() << "%" << std::endl;
+  std::cout << "Accuracy of quantized model on test set: " << trt_accuracy.item().toFloat() << "%" << std::endl;
 
   /// Time execution in JIT-FP32 and TRT-INT8
   std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};