1
+ #pragma once
2
+ #include " ATen/tensor.h"
3
+ #include " NvInfer.h"
4
+
5
+ namespace trtorch {
6
+ namespace core {
7
+ namespace quantization {
8
+
9
+ enum class CalibratorKind {
10
+ kEntropy ,
11
+ kMinMax ,
12
+ }
13
+
14
+ in conveter or whatever
15
+ in order given std::vector<at::Tensor> -> map<input_name, at::Tensor>
16
+
17
+ struct QuantizationSettings {
18
+ CalibratorKind calibrator_type = CalibratorKind::kEntropy ;
19
+ const std::string& calibration_cache_file = " " ;
20
+ bool use_cache = false ;
21
+ std::unordered_map<std::string, at::Tensor> calibration_dataset;
22
+ };
23
+
24
+ class CalibrationBatchStream {
25
+
26
+ };
27
+
28
+ class Int8CalibratorImpl {
29
+ public:
30
+ TRTInt8CalibratorImpl (QuantizationSettings& settings);
31
+ int GetBatchSize () const ;
32
+ bool GetBatch (void * bindings[], const char * names[], int num_bindings);
33
+ const void * ReadCalibrationCache (size_t & length);
34
+ void WriteCalibrationCache (const void * cache, size_t length);
35
+ private:
36
+ std::unordered_map<std::string, at::Tensor> dataset_;
37
+ const std::string& cache_file_path_;
38
+ std::vector<char > cache_;
39
+ bool use_cache_;
40
+ size_t cache_size_ = 0 ;
41
+ };
42
+
43
+ class TRTInt8EntropyCalibrator : nvinfer1::IInt8EntropyCalibrator2 {
44
+ public:
45
+ TRTInt8EntropyCalibrator (Int8CalibratorImpl impl) : impl_(impl) {}
46
+ int getBatchSize () const override {return impl_.GetBatchSize ();}
47
+ bool getBatch (void * bindings[], const char * names[], int nbBindings) override {return impl_.GetBatch (bindings, names, nbBindings)};
48
+ const void * readCalibrationCache (size_t & length) override {return impl_.ReadCalibrationCache (size_t & length)};
49
+ void writeCalibrationCache (const void * cache, size_t length) override {impl_.WriteCalibrationCache (const void * cache, size_t length)};
50
+ private:
51
+ Int8CalibratorImpl impl_;
52
+ };
53
+
54
+ class TRTInt8MinMaxCalibrator : nvinfer1::IInt8MinMaxCalibrator {
55
+ public:
56
+ TRTInt8EntropyCalibrator (Int8CalibratorImpl impl) : impl_(impl) {}
57
+ int getBatchSize () const override {return impl_.GetBatchSize ();}
58
+ bool getBatch (void * bindings[], const char * names[], int nbBindings) override {return impl_.GetBatch (bindings, names, nbBindings)};
59
+ const void * readCalibrationCache (size_t & length) override {return impl_.ReadCalibrationCache (size_t & length)};
60
+ void writeCalibrationCache (const void * cache, size_t length) override {impl_.WriteCalibrationCache (const void * cache, size_t length)};
61
+ private:
62
+ Int8CalibratorImpl impl_;
63
+ };
64
+
65
+ nvinfer1::IInt8Calibrator create_int8_calibrator (QuantizationSettings settings);
66
+
67
+ } // quantization
68
+ } // core
69
+ } // trtorch
0 commit comments