Commit 624ec8a

chore: doc update (#2967)
1 parent f5167a8

1 file changed: docsrc/ts/getting_started_with_cpp_api.rst (16 additions, 14 deletions)

@@ -100,7 +100,7 @@ As you can see it is pretty similar to the Python API. When you call the ``forwa

 Compiling with Torch-TensorRT in C++
 -------------------------------------
-We are also at the point were we can compile and optimize our module with Torch-TensorRT, but instead of in a JIT fashion we must do it ahead-of-time (AOT) i.e. before we start doing actual inference work
+We are also at the point where we can compile and optimize our module with Torch-TensorRT, but instead of in a JIT fashion we must do it ahead-of-time (AOT) i.e. before we start doing actual inference work
 since it takes a bit of time to optimize the module, it would not make sense to do this every time you run the module or even the first time you run it.

 With our module loaded, we can feed it into the Torch-TensorRT compiler. When we do so we must provide some information on the expected input size and also configure any additional settings.
@@ -113,9 +113,10 @@ With our module loaded, we can feed it into the Torch-TensorRT compiler. When we

 mod.to(at::kCUDA);
 mod.eval();
-
-auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA});
-auto trt_mod = torch_tensorrt::CompileGraph(mod, std::vector<torch_tensorrt::CompileSpec::InputRange>{{in.sizes()}});
+std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
+torch_tensorrt::ts::CompileSpec cfg(inputs);
+auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
+auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA});
 auto out = trt_mod.forward({in});

 Thats it! Now the graph runs primarily not with the JIT compiler but using TensorRT (though we execute the graph using the JIT runtime).
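
The lines added above assemble into roughly the following self-contained program. This is a sketch only: the include lines, the model path, and the main() wrapper are assumptions not shown in the diff, while the torch_tensorrt calls mirror the updated snippet.

    // Sketch: ahead-of-time compilation of a TorchScript module with Torch-TensorRT.
    // Headers and the model path are assumptions; the torch_tensorrt calls follow the updated doc.
    #include <iostream>
    #include <torch/script.h>
    #include "torch_tensorrt/torch_tensorrt.h"

    int main() {
      // Load a scripted/traced module (placeholder path) and move it to the GPU.
      torch::jit::Module mod = torch::jit::load("resnet50_scripted.ts");
      mod.to(at::kCUDA);
      mod.eval();

      // Describe the expected input shape for the compiler.
      std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
      torch_tensorrt::ts::CompileSpec cfg(inputs);

      // AOT compile: supported subgraphs are lowered to TensorRT engines.
      auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);

      // Inference still goes through the familiar JIT runtime interface.
      auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA});
      auto out = trt_mod.forward({in});
      std::cout << out << std::endl;
      return 0;
    }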
@@ -131,11 +132,11 @@ We can also set settings like operating precision to run in FP16.
 mod.to(at::kCUDA);
 mod.eval();

-auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF);
-auto input_sizes = std::vector<torch_tensorrt::CompileSpec::InputRange>({in.sizes()});
-torch_tensorrt::CompileSpec info(input_sizes);
-info.enable_precisions.insert(torch::kHALF);
-auto trt_mod = torch_tensorrt::CompileGraph(mod, info);
+auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA}).to(torch::kHALF);
+std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
+torch_tensorrt::ts::CompileSpec cfg(inputs);
+cfg.enable_precisions.insert(torch::kHALF);
+auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
 auto out = trt_mod.forward({in});

 And now we are running the module in FP16 precision. You can then save the module to load later.
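
Since the text above says the compiled module can be saved and loaded later, here is a minimal sketch of that round trip, assuming trt_mod is the torch::jit::Module returned by torch_tensorrt::ts::compile and that the Torch-TensorRT runtime library is available wherever the module is reloaded (the file name is a placeholder):

    // Persist the TensorRT-embedded TorchScript module like any other torch::jit::Module.
    trt_mod.save("trt_module_fp16.ts");

    // Later, possibly in a different process: reload and run it. The embedded
    // TensorRT engines execute through the JIT runtime, so the Torch-TensorRT
    // runtime must be linked or loaded in that process.
    torch::jit::Module reloaded = torch::jit::load("trt_module_fp16.ts");
    reloaded.to(at::kCUDA);
    auto result = reloaded.forward({torch::randn({1, 3, 224, 224}, {torch::kCUDA}).to(torch::kHALF)});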
@@ -179,11 +180,12 @@ If you want to save the engine produced by Torch-TensorRT to use in a TensorRT a
 mod.to(at::kCUDA);
 mod.eval();

-auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF);
-auto input_sizes = std::vector<torch_tensorrt::CompileSpec::InputRange>({in.sizes()});
-torch_tensorrt::CompileSpec info(input_sizes);
-info.enabled_precisions.insert(torch::kHALF);
-auto trt_mod = torch_tensorrt::ConvertGraphToTRTEngine(mod, "forward", info);
+auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA}).to(torch::kHALF);
+
+std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
+torch_tensorrt::ts::CompileSpec cfg(inputs);
+cfg.enabled_precisions.insert(torch::kHALF);
+auto trt_mod = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", cfg);
 std::ofstream out("/tmp/engine_converted_from_jit.trt");
 out << engine;
 out.close();
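
For completeness, a standalone TensorRT application could read that file back roughly as follows. This is a sketch against the TensorRT 8 C++ runtime API; the logger class, the includes, and the assumption that the serialized engine string was written to /tmp/engine_converted_from_jit.trt are not taken from the diff.

    // Sketch: deserialize the saved engine with the plain TensorRT runtime.
    #include <fstream>
    #include <iostream>
    #include <iterator>
    #include <string>
    #include <NvInfer.h>

    // TensorRT needs an ILogger implementation; this minimal one prints every message.
    class SimpleLogger : public nvinfer1::ILogger {
      void log(Severity severity, const char* msg) noexcept override {
        std::cout << msg << std::endl;
      }
    };

    int main() {
      // Read the serialized engine bytes written by the snippet above.
      std::ifstream file("/tmp/engine_converted_from_jit.trt", std::ios::binary);
      std::string blob((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());

      SimpleLogger logger;
      nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
      nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(blob.data(), blob.size());
      nvinfer1::IExecutionContext* context = engine->createExecutionContext();

      // From here, inference proceeds with the regular TensorRT API:
      // bind input/output buffers and enqueue the context on a CUDA stream.
      return 0;
    }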
