
Commit 9835ce0

refactor: [collection] fuse Input with GraphInputs
Signed-off-by: inocsin <[email protected]>
1 parent f397faf · commit 9835ce0

File tree: 4 files changed, +23 -28 lines changed


core/compiler.h
Lines changed: 3 additions & 2 deletions

@@ -14,12 +14,13 @@ namespace torch_tensorrt {
 namespace core {
 
 struct CompileSpec {
-  CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
+  CompileSpec(std::vector<ir::Input> inputs) {
+    graph_inputs.inputs = inputs;
+  }
   CompileSpec(torch::jit::IValue& input_signature) {
     graph_inputs.input_signature = input_signature;
   }
   ir::GraphInputs graph_inputs;
-  std::vector<ir::Input> inputs; // can be replaced by graph_inputs
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
   partitioning::PartitionInfo partition_info;
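The practical effect of this hunk: the internal CompileSpec no longer keeps its own inputs vector, so flattened specs are reached through graph_inputs.inputs. A minimal sketch of building the internal spec after the change; the include path and struct names come from the diff, while the shape value and the ir::Input shape-vector constructor are assumptions:

#include <vector>

#include "core/compiler.h"

// Construct the fused spec: the vector constructor now forwards the flattened
// specs into graph_inputs.inputs instead of a removed CompileSpec::inputs member.
torch_tensorrt::core::CompileSpec make_internal_spec() {
  std::vector<torch_tensorrt::core::ir::Input> flat;
  flat.push_back(torch_tensorrt::core::ir::Input({1, 3, 224, 224})); // illustrative static shape
  torch_tensorrt::core::CompileSpec spec(flat);
  // spec.graph_inputs.inputs now holds the flattened specs.
  return spec;
}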

core/ir/ir.h
Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ struct Input : torch::CustomClassHolder {
 // Add to spec
 struct GraphInputs {
   torch::jit::IValue input_signature; // nested Input, full input spec
-  std::vector<Input> flattened_inputs; // flattend Input
+  std::vector<Input> inputs; // flattend Input
   std::vector<std::vector<Input>> collection_inputs; // only support two layer nesting, e.g. ((a, b), [c, d], e)
 };
 
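After the rename, GraphInputs keeps three parallel views of one specification: the nested input_signature, the flattened inputs, and the at-most-two-level collection_inputs. A hedged sketch of what the fields would hold for a nested signature such as ((a, b), c); the shapes and the Input shape constructor are illustrative assumptions, only the field names come from the struct above:

#include "core/ir/ir.h"

// Populate the three views by hand for a signature shaped like ((a, b), c).
torch_tensorrt::core::ir::GraphInputs example_graph_inputs() {
  namespace ir = torch_tensorrt::core::ir;
  ir::Input a({1, 3, 16, 16});
  ir::Input b({1, 3, 8, 8});
  ir::Input c({1, 10});

  ir::GraphInputs gi;
  gi.inputs = {a, b, c};                // flattened view, in traversal order
  gi.collection_inputs = {{a, b}, {c}}; // one group per top-level element
  // gi.input_signature would carry the matching nested IValue when built by
  // the frontend; it is left default-constructed in this sketch.
  return gi;
}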

cpp/include/torch_tensorrt/torch_tensorrt.h
Lines changed: 5 additions & 12 deletions

@@ -517,10 +517,11 @@ struct TORCHTRT_API Input : torch::CustomClassHolder{
 /**
  * @brief A struct to hold complex inputs
  *
- * This struct can either hold a conplex inputs of shape or a flattened one,
+ * This struct can either hold a complex inputs of shape or a flattened one,
  */
 struct TORCHTRT_API GraphInputs {
-  torch::jit::IValue input_signature; // nested Input, full input spec
+  torch::jit::IValue input_signature; // nested Input, full input spec
+  std::vector<Input> inputs; // flatten input spec
 };
 
 /**
@@ -590,25 +591,17 @@ struct TORCHTRT_API CompileSpec {
    *
    * @param inputs
    */
-  CompileSpec(std::vector<Input> inputs) : inputs(std::move(inputs)) {}
+  CompileSpec(std::vector<Input> inputs);
 
   /**
    * @brief Construct a new Extra Info object from IValue.
    * The IValue store a complex Input
    *
-   * @param inputs
+   * @param input_signature
    */
   CompileSpec(torch::jit::IValue input_signature);
   // Defaults should reflect TensorRT defaults for BuilderConfig
 
-  /**
-   * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max
-   * sizes Users can also specify expected input type as well as tensor memory format
-   *
-   * Order in vector should match call order for the function
-   */
-  std::vector<Input> inputs;
-
   /**
    * @brief Specifications for inputs to the engine, can store a IValue which has stored complex Input
    * or a flatened Input
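The public API mirrors the internal change: GraphInputs gains a flattened inputs vector, the documented std::vector<Input> member on CompileSpec is removed, and the vector constructor is now declared here and defined out-of-line in cpp/src/compile_spec.cpp. A short usage sketch against this header; the shape value is an illustrative assumption:

#include "torch_tensorrt/torch_tensorrt.h"

// Both construction paths end up on graph_inputs: a flattened vector of Input,
// or a nested IValue signature for collection-style inputs.
torch_tensorrt::torchscript::CompileSpec make_public_spec() {
  torch_tensorrt::Input in({1, 3, 224, 224}); // illustrative static shape
  torch_tensorrt::torchscript::CompileSpec spec({in});
  // spec.graph_inputs.inputs now holds {in}; the former spec.inputs member no longer exists.
  return spec;
}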

cpp/src/compile_spec.cpp
Lines changed: 14 additions & 13 deletions

@@ -18,22 +18,26 @@ torchtrt::core::runtime::CudaDevice to_internal_cuda_device(Device device);
 namespace torchscript {
 CompileSpec::CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
   for (auto in : fixed_sizes) {
-    inputs.push_back(Input(in));
+    graph_inputs.inputs.push_back(Input(in));
   }
-  // graph_inputs.flattened_inputs = inputs;
 }
 
 CompileSpec::CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes) {
   for (auto in : fixed_sizes) {
-    inputs.push_back(Input(in));
+    graph_inputs.inputs.push_back(Input(in));
   }
-  // graph_inputs.flattened_inputs = inputs;
+}
+
+CompileSpec::CompileSpec(std::vector<Input> inputs) {
+  graph_inputs.inputs = std::move(inputs);
 }
 
 CompileSpec::CompileSpec(torch::jit::IValue input_signature) {
   graph_inputs.input_signature = input_signature;
 }
 
+
+
 void flatten_dfs(std::vector<torchtrt::core::ir::Input>& flattened_inputs, std::vector<std::vector<torchtrt::core::ir::Input>>& collection_inputs,
     torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue, int level, int index) {
   if (input_ivalue.isTuple()) {
@@ -59,7 +63,6 @@ void flatten_dfs(std::vector<torchtrt::core::ir::Input>& flattened_inputs, std::
   }
   c10::TypePtr type = input_list[0].type();
   auto converted_elements = c10::impl::GenericList(type);
-  // std::vector<torch::jit::IValue> converted_elements;
   int idx = 0;
   for (auto item: input_list) {
     int cur_idx = level < 1 ? idx: index;
@@ -95,7 +98,7 @@ torch_tensorrt::core::ir::GraphInputs to_internal_graph_inputs(GraphInputs exter
 
   torch::jit::IValue converted_input_signature;
   flatten_dfs(flattened_inputs, collection_inputs, external_graph_input.input_signature, converted_input_signature, 0, 0);
-  internal_graph_input.flattened_inputs = flattened_inputs;
+  internal_graph_input.inputs = flattened_inputs;
   internal_graph_input.input_signature = converted_input_signature;
   internal_graph_input.collection_inputs = collection_inputs;
 
@@ -105,17 +108,15 @@ torch_tensorrt::core::ir::GraphInputs to_internal_graph_inputs(GraphInputs exter
 }
 
 torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
-  torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.inputs));
-  if (internal.inputs.size() == 0) {
+  torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.graph_inputs.inputs));
+  if (internal.graph_inputs.inputs.size() == 0) {
     LOG_DEBUG("GraphInput.inputs size == 0, using GraphInput.input_signature to get Input spec");
     internal.graph_inputs = to_internal_graph_inputs(external.graph_inputs);
-    internal.inputs = internal.graph_inputs.flattened_inputs;
   } else {
     LOG_DEBUG("GraphInput.inputs size != 0, using GraphInput.inputs to get Input spec");
-    internal.graph_inputs.collection_inputs.resize(internal.inputs.size());
-    for (int i = 0; i < internal.inputs.size(); i++) {
-      internal.graph_inputs.collection_inputs[i].push_back(internal.inputs[i]);
-      internal.graph_inputs.flattened_inputs = internal.inputs;
+    internal.graph_inputs.collection_inputs.resize(internal.graph_inputs.inputs.size());
+    for (int i = 0; i < internal.graph_inputs.inputs.size(); i++) {
+      internal.graph_inputs.collection_inputs[i].push_back(internal.graph_inputs.inputs[i]);
     }
   }
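to_internal_compile_spec now branches on the fused field: when graph_inputs.inputs is empty, the nested input_signature is flattened through flatten_dfs; otherwise each flattened Input is wrapped as its own single-element group in collection_inputs. A standalone sketch of that wrapping invariant, with plain ints standing in for ir::Input so it compiles on its own:

#include <cassert>
#include <vector>

// Mirrors the else branch above: N flattened inputs become N one-element
// groups in the collection view.
int main() {
  std::vector<int> flat = {10, 20, 30}; // stand-ins for three flattened Input specs
  std::vector<std::vector<int>> collection(flat.size());
  for (size_t i = 0; i < flat.size(); i++) {
    collection[i].push_back(flat[i]);
  }
  assert(collection.size() == 3);
  assert(collection[0].size() == 1 && collection[0][0] == 10);
  return 0;
}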
