@@ -23,19 +23,29 @@ std::vector<core::conversion::InputRange> toInputRanges(std::vector<at::Tensor>
23
23
return std::move (a);
24
24
}
25
25
26
- std::vector<core::conversion::InputRange> toInputRangesDynamic (std::vector<at::Tensor> ten) {
26
+ std::vector<core::conversion::InputRange> toInputRangesDynamic (std::vector<at::Tensor> ten, bool dynamic_batch ) {
27
27
std::vector<core::conversion::InputRange> a;
28
28
29
29
for (auto i : ten) {
30
30
auto opt = core::util::toVec (i.sizes ());
31
31
32
- std::vector<int64_t > min_range (opt);
33
- std::vector<int64_t > max_range (opt);
32
+ if (dynamic_batch) {
33
+ std::vector<int64_t > min_range (opt);
34
+ std::vector<int64_t > max_range (opt);
34
35
35
- min_range[1 ] = ceil (opt[1 ] / 2.0 );
36
- max_range[1 ] = 2 * opt[1 ];
36
+ min_range[0 ] = ceil (opt[0 ] / 2.0 );
37
+ max_range[0 ] = 2 * opt[0 ];
37
38
38
- a.push_back (core::conversion::InputRange (min_range, opt, max_range));
39
+ a.push_back (core::conversion::InputRange (min_range, opt, max_range));
40
+ } else {
41
+ std::vector<int64_t > min_range (opt);
42
+ std::vector<int64_t > max_range (opt);
43
+
44
+ min_range[1 ] = ceil (opt[1 ] / 2.0 );
45
+ max_range[1 ] = 2 * opt[1 ];
46
+
47
+ a.push_back (core::conversion::InputRange (min_range, opt, max_range));
48
+ }
39
49
}
40
50
41
51
return std::move (a);
@@ -63,9 +73,10 @@ std::vector<at::Tensor> RunGraphEngine(
63
73
std::vector<at::Tensor> RunGraphEngineDynamic (
64
74
std::shared_ptr<torch::jit::Graph>& g,
65
75
core::conversion::GraphParams& named_params,
66
- std::vector<at::Tensor> inputs) {
76
+ std::vector<at::Tensor> inputs,
77
+ bool dynamic_batch) {
67
78
LOG_DEBUG (" Running TRT version" );
68
- auto in = toInputRangesDynamic (inputs);
79
+ auto in = toInputRangesDynamic (inputs, dynamic_batch );
69
80
auto info = core::conversion::ConversionInfo (in);
70
81
info.engine_settings .workspace_size = 1 << 20 ;
71
82
std::string eng = core::conversion::ConvertBlockToEngine (g->block (), info, named_params);
0 commit comments