@@ -18,22 +18,26 @@ torchtrt::core::runtime::CudaDevice to_internal_cuda_device(Device device);
 namespace torchscript {
 CompileSpec::CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
   for (auto in : fixed_sizes) {
-    inputs.push_back(Input(in));
+    graph_inputs.inputs.push_back(Input(in));
   }
-  // graph_inputs.flattened_inputs = inputs;
 }
 
 CompileSpec::CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes) {
   for (auto in : fixed_sizes) {
-    inputs.push_back(Input(in));
+    graph_inputs.inputs.push_back(Input(in));
   }
-  // graph_inputs.flattened_inputs = inputs;
+}
+
+CompileSpec::CompileSpec(std::vector<Input> inputs) {
+  graph_inputs.inputs = std::move(inputs);
 }
 
 CompileSpec::CompileSpec(torch::jit::IValue input_signature) {
   graph_inputs.input_signature = input_signature;
 }
 
+
+
 void flatten_dfs(std::vector<torchtrt::core::ir::Input>& flattened_inputs, std::vector<std::vector<torchtrt::core::ir::Input>>& collection_inputs,
     torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue, int level, int index) {
   if (input_ivalue.isTuple()) {
@@ -59,7 +63,6 @@ void flatten_dfs(std::vector<torchtrt::core::ir::Input>& flattened_inputs, std::
     }
     c10::TypePtr type = input_list[0].type();
     auto converted_elements = c10::impl::GenericList(type);
-    // std::vector<torch::jit::IValue> converted_elements;
     int idx = 0;
     for (auto item: input_list) {
       int cur_idx = level < 1 ? idx: index;
@@ -95,7 +98,7 @@ torch_tensorrt::core::ir::GraphInputs to_internal_graph_inputs(GraphInputs exter
 
   torch::jit::IValue converted_input_signature;
   flatten_dfs(flattened_inputs, collection_inputs, external_graph_input.input_signature, converted_input_signature, 0, 0);
-  internal_graph_input.flattened_inputs = flattened_inputs;
+  internal_graph_input.inputs = flattened_inputs;
   internal_graph_input.input_signature = converted_input_signature;
   internal_graph_input.collection_inputs = collection_inputs;
 
@@ -105,17 +108,15 @@ torch_tensorrt::core::ir::GraphInputs to_internal_graph_inputs(GraphInputs exter
 }
 
 torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
-  torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.inputs));
-  if (internal.inputs.size() == 0) {
+  torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.graph_inputs.inputs));
+  if (internal.graph_inputs.inputs.size() == 0) {
     LOG_DEBUG("GraphInput.inputs size == 0, using GraphInput.input_signature to get Input spec");
     internal.graph_inputs = to_internal_graph_inputs(external.graph_inputs);
-    internal.inputs = internal.graph_inputs.flattened_inputs;
   } else {
     LOG_DEBUG("GraphInput.inputs size != 0, using GraphInput.inputs to get Input spec");
-    internal.graph_inputs.collection_inputs.resize(internal.inputs.size());
-    for (int i = 0; i < internal.inputs.size(); i++) {
-      internal.graph_inputs.collection_inputs[i].push_back(internal.inputs[i]);
-      internal.graph_inputs.flattened_inputs = internal.inputs;
+    internal.graph_inputs.collection_inputs.resize(internal.graph_inputs.inputs.size());
+    for (int i = 0; i < internal.graph_inputs.inputs.size(); i++) {
+      internal.graph_inputs.collection_inputs[i].push_back(internal.graph_inputs.inputs[i]);
     }
   }
 
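
Usage sketch (not part of the patch): the snippet below shows how the reworked constructors might be exercised through the public C++ API, assuming the usual torch_tensorrt/torch_tensorrt.h header and that torch_tensorrt::Input and torch_tensorrt::torchscript::CompileSpec are the public counterparts of the types edited above; the shapes are illustrative only.

#include <cstdint>
#include <vector>
#include "torch_tensorrt/torch_tensorrt.h"  // assumed public header path

int main() {
  namespace ts = torch_tensorrt::torchscript;

  // New constructor: a prebuilt vector of Inputs is moved straight into
  // graph_inputs.inputs.
  std::vector<torch_tensorrt::Input> inputs;
  inputs.push_back(torch_tensorrt::Input(std::vector<int64_t>{1, 3, 224, 224}));
  auto spec_from_inputs = ts::CompileSpec(inputs);

  // Existing fixed-size constructor: each shape is wrapped in an Input and now
  // pushed onto graph_inputs.inputs instead of the removed flat `inputs` member.
  auto spec_from_shapes =
      ts::CompileSpec(std::vector<std::vector<int64_t>>{{1, 3, 224, 224}});
  return 0;
}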