Skip to content

Commit db20098

Browse files
authored
Merge pull request #74 from NVIDIA/seralization
Adds support for serialization and deseralization for compiled TorchScript modules
2 parents 40564c3 + b763332 commit db20098

File tree

119 files changed

+6989
-347
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

119 files changed

+6989
-347
lines changed

BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ pkg_tar(
88
"//core/conversion:include",
99
"//core/conversion/conversionctx:include",
1010
"//core/conversion/converters:include",
11+
"//core/conversion/var:include",
12+
"//core/conversion/tensorcontainer:include",
1113
"//core/conversion/evaluators:include",
1214
"//core/execution:include",
1315
"//core/lowering:include",
@@ -35,6 +37,15 @@ pkg_tar(
3537
)
3638

3739

40+
pkg_tar(
41+
name = "bin",
42+
package_dir = "bin/",
43+
srcs = [
44+
"//cpp/trtorchc:trtorchc",
45+
],
46+
mode = "0755",
47+
)
48+
3849

3950

4051
pkg_tar(
@@ -46,6 +57,7 @@ pkg_tar(
4657
],
4758
deps = [
4859
":lib",
60+
":bin",
4961
":include",
5062
":include_core",
5163
],

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ compile_settings.op_precision = torch::kFloat;
2323
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
2424
// Run like normal
2525
auto results = trt_mod.forward({in_tensor});
26+
// Save module for later
27+
trt_mod.save("trt_torchscript_module.ts");
2628
...
2729
```
2830
@@ -46,6 +48,7 @@ trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
4648
4749
input_data = input_data.half()
4850
result = trt_ts_module(input_data)
51+
torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
4952
```
5053

5154
> Notes on running in lower precisions:

core/compiler.cpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
#include "NvInfer.h"
77

88
#include "ATen/core/function_schema.h"
9+
#include "ATen/core/jit_type.h"
910

11+
#include "torch/custom_class.h"
1012
#include "torch/csrc/jit/frontend/function_schema_parser.h"
1113
#include "torch/csrc/jit/ir/ir.h"
1214
#include "torch/csrc/jit/passes/pass_manager.h"
@@ -40,32 +42,70 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str
4042

4143

4244
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
43-
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
44-
auto num_io = execution::GetEngineIO(uid);
45-
46-
auto self = g->addInput("self.1");
45+
auto engine = execution::TRTEngine(mod._ivalue()->name(), serialized_engine);
46+
// Get required metadata about the engine out
47+
auto num_io = engine.num_io;
48+
auto name = engine.name;
49+
50+
// Add the engine as an attribute of the module, this will let the engine be serialized and deserialized
51+
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(engine);
52+
mod.register_attribute(
53+
name,
54+
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
55+
c10::IValue(std::move(engine_ptr)),
56+
false
57+
);
58+
59+
// Add the module as an input into the graph
60+
auto self = g->addInput("self_1");
4761
self->setType(mod.type());
4862

49-
auto id_val = g->insertConstant(uid);
63+
// Start by retriveing the engine from the module attribute list
64+
auto engine_node = g->createGetAttr(self, name);
65+
g->block()->appendNode(engine_node);
5066

67+
// Add inputs to the graph corresponding to the number of input tensors expected by the engine
68+
// Also store those inputs in a vector so that they can be coalesced into a single list at runtime
5169
std::vector<torch::jit::Value*> engine_inputs;
52-
engine_inputs.push_back(id_val);
53-
5470
for (uint64_t i = 0; i < num_io.first; i++) {
55-
auto in_val = g->addInput("");
71+
auto in_val = g->addInput(std::string("input_") + std::to_string(i));
5672
in_val->setType(c10::TensorType::get());
5773
engine_inputs.push_back(in_val);
5874
}
5975

60-
auto engine_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs), num_io.second);
61-
g->block()->appendNode(engine_node);
62-
63-
if (engine_node->outputs().size() > 1) {
64-
auto return_tuple_node = g->createTuple(engine_node->outputs());
76+
// Create a node that will merge all of the input tensors into a single list argument to the trt::execute_engine op
77+
// Creates: prim::ListConstruct(<input tensors>)
78+
auto input_list_node = g->createList(c10::TensorType::get(), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs));
79+
g->block()->appendNode(input_list_node);
80+
81+
// Make a list of inputs to the actual trt::execute_engine op
82+
// Note: Ordering of list and then engine is because we can pop off the engine first which contains all the metadata
83+
// needed for execution
84+
std::vector<torch::jit::Value*> execute_node_inputs;
85+
execute_node_inputs.push_back(input_list_node->outputs()[0]);
86+
execute_node_inputs.push_back(engine_node->outputs()[0]);
87+
88+
// Create the actual execution node trt::execute_engine using the assembled inputs
89+
auto execute_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(execute_node_inputs), 1);
90+
g->block()->appendNode(execute_node);
91+
execute_node->outputs()[0]->setType(c10::ListType::ofTensors());
92+
93+
// Create a node to unpack the list into seperate tensors, in the case of there being only one tensor, the tensor will be returned,
94+
// otherwise they are returned as a tuple of tensors.
95+
// Creates: prim::ListUnpack(<engine output>)
96+
auto unpack_node = g->createListUnpack(execute_node->outputs()[0], num_io.second);
97+
g->block()->appendNode(unpack_node);
98+
99+
// If there are multiple output tensors from TensorRT we wrap them in a tuple to return
100+
if (unpack_node->outputs().size() > 1) {
101+
// Creates prim::TupleConstruct(<output tensors>) using outputs of the unpack node
102+
auto return_tuple_node = g->createTuple(unpack_node->outputs());
65103
g->block()->appendNode(return_tuple_node);
104+
// Set the output as the produced tuple
66105
g->registerOutput(return_tuple_node->outputs()[0]);
67106
} else {
68-
g->registerOutput(engine_node->outputs()[0]);
107+
// Set the output as the sole output tensor
108+
g->registerOutput(unpack_node->outputs()[0]);
69109
}
70110

71111
LOG_DEBUG(*g << "(AddEngineToGraph)\n");

core/conversion/InterfaceTypes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ InputRange::InputRange(std::vector<int64_t> d) {
3434
min = util::toDims(d);
3535
max = util::toDims(d);
3636
input_shape = util::toDims(d);
37-
37+
input_is_dynamic = false;
3838
}
3939

4040

@@ -67,6 +67,7 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
6767
dim.insert(max_shape[i]);
6868
if (dim.size() != 1) {
6969
dyn_shape.push_back(-1);
70+
input_is_dynamic = true;
7071
} else {
7172
dyn_shape.push_back(opt_shape[i]);
7273
}

core/conversion/conversion.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ void AddInputs(ConversionCtx* ctx,
155155
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
156156
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);
157157

158+
if (dims.input_is_dynamic) {
159+
ctx->input_is_dynamic = true;
160+
}
161+
158162
ctx->value_tensor_map[in] = trt_in;
159163
}
160164

core/conversion/conversion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ struct InputRange {
1515
nvinfer1::Dims max;
1616
nvinfer1::Dims opt;
1717
nvinfer1::Dims input_shape;
18+
bool input_is_dynamic = false;
1819
// Should we restrict to unsigned?
1920
InputRange(std::vector<int64_t> d);
2021
InputRange(std::vector<int64_t> min_shape,

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ struct ConversionCtx {
4242

4343
~ConversionCtx();
4444

45+
bool input_is_dynamic = false;
4546
nvinfer1::IBuilder* builder;
4647
nvinfer1::INetworkDefinition* net;
4748
nvinfer1::IBuilderConfig* cfg;

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,24 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1919
auto orig_shape = input->getDimensions();
2020
auto shape = util::toVec(orig_shape);
2121
auto options = torch::TensorOptions().dtype(torch::kFloat32);
22-
auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
23-
auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
24-
auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
25-
auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
22+
23+
torch::Tensor gamma, beta, mean, var;
24+
25+
if (ctx->input_is_dynamic) {
26+
gamma = args[1].unwrapToTensor();
27+
beta = args[2].unwrapToTensor();
28+
mean = args[3].unwrapToTensor();
29+
var = args[4].unwrapToTensor();
30+
} else {
31+
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
32+
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
33+
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
34+
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
35+
}
36+
2637
auto eps = args[7].unwrapToDouble(1e-5f);
2738

39+
2840
LOG_DEBUG("momentum disregarded");
2941
LOG_DEBUG("training disregarded");
3042
LOG_DEBUG("cudnn disregarded");

core/conversion/converters/impl/concat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace conversion {
88
namespace converters {
99
namespace impl {
1010
namespace {
11-
auto cat_registrations = RegisterNodeConversionPatterns()
11+
auto cat_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1212
.pattern({
1313
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
1414
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/constant.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace conversion {
77
namespace converters {
88
namespace impl {
99
namespace {
10-
auto constant_registrations = RegisterNodeConversionPatterns()
10+
auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1111
.pattern({
1212
"trt::const(Tensor self) -> Tensor",
1313
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

0 commit comments

Comments
 (0)