@@ -128,22 +128,6 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
128128 return conversion::VerifyConverterSupportForBlock (g->block ());
129129}
130130
131- std::string ConvertGraphToTRTEngine (const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
132- // Go through Lowering to simplify graph and extract weight parameters
133- auto graph_and_parameters = lowering::Lower (mod, method_name, cfg.lower_info );
134-
135- auto convert_cfg = std::move (cfg.convert_info );
136- auto g = graph_and_parameters.first ;
137-
138- auto params = graph_and_parameters.second ;
139- auto named_params = conversion::get_named_params (g->inputs (), params);
140-
141- LOG_INFO (*g << " (CompileGraph)\n " );
142-
143- auto engine = conversion::ConvertBlockToEngine (g->block (), convert_cfg, named_params);
144- return std::move (engine);
145- }
146-
147131void AddSegmentedBlockToGraph (
148132 std::shared_ptr<torch::jit::Graph>& g,
149133 partitioning::SegmentedBlock& seg,
@@ -237,15 +221,15 @@ void AddIfBlockToGraph(
237221GraphAndMapping ConstructFallbackGraph (
238222 torch::jit::script::Module& new_mod,
239223 torch::jit::Block* block,
240- std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map ,
224+ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map ,
241225 CompileSpec cfg,
242- conversion::GraphParams named_params ) {
226+ ir::StaticParams static_params ) {
243227 auto convert_cfg = cfg.convert_info ;
244228 auto partition_info = cfg.partition_info ;
245229
246230 auto new_g = std::make_shared<torch::jit::Graph>();
247231
248- auto segmented_blocks = partitioning::Partition (block, input_ivalues_map , partition_info);
232+ auto segmented_blocks = partitioning::Partition (block, example_tensor_map , partition_info);
249233
250234 // the mapping from lowering graph => fallback global graph
251235 std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -259,13 +243,17 @@ GraphAndMapping ConstructFallbackGraph(
259243 trt_engine_id << reinterpret_cast <const int *>(&seg_block);
260244
261245 if (seg_block.target () == partitioning::SegmentedBlock::kTensorRT ) {
246+ auto shapes = seg_block.in_shapes ();
247+ auto types = seg_block.in_types ();
262248 std::vector<ir::Input> inputs;
263- for (auto & shape : seg_block.in_shape ()) {
264- inputs.push_back (ir::Input (shape));
249+ for (size_t i = 0 ; i < shapes.size (); i++) {
250+ auto in = ir::Input (shapes[i]);
251+ in.dtype = util::ScalarTypeToTRTDataType (types[i]);
252+ inputs.push_back (in);
265253 }
266254 // update the input ranges for each segment
267- convert_cfg.inputs = inputs;
268- auto engine = conversion::ConvertBlockToEngine (seg_block.block (), convert_cfg, named_params );
255+ convert_cfg.inputs = ir::associate_specs_with_inputs (seg_block. g (), inputs, static_params) ;
256+ auto engine = conversion::ConvertBlockToEngine (seg_block.block (), convert_cfg, static_params );
269257 auto temp_g = std::make_shared<torch::jit::Graph>();
270258 auto device_spec = convert_cfg.engine_settings .device ;
271259 auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
@@ -281,7 +269,7 @@ GraphAndMapping ConstructFallbackGraph(
281269 std::vector<GraphAndMapping> graph_and_mappings;
282270 for (auto cur_block : if_node->blocks ()) {
283271 graph_and_mappings.push_back (
284- ConstructFallbackGraph (new_mod, cur_block, input_ivalues_map , cfg, named_params ));
272+ ConstructFallbackGraph (new_mod, cur_block, example_tensor_map , cfg, static_params ));
285273 }
286274 AddIfBlockToGraph (new_g, if_node, graph_and_mappings, old_to_new_g);
287275
@@ -299,54 +287,28 @@ GraphAndMapping ConstructFallbackGraph(
299287 return {new_g, old_to_new_g};
300288}
301289
302- torch::jit::script::Module CompileGraphWithFallback (const torch::jit::script::Module& mod, CompileSpec cfg) {
303- // TODO: Should be doing a functional transform but need PR #31978
304- // [jit] More robust mangling
305- // torch::jit::script::Module new_mod = mod.clone();
306- torch::jit::script::Module new_mod (mod._ivalue ()->name () + " _trt" );
307- std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
308- for (const torch::jit::script::Method& method : mod.get_methods ()) {
309- // Compile only forward methods. forward method contains the entire graph.
310- if (method.name ().compare (" forward" ) == 0 ) {
311- auto new_g = std::make_shared<torch::jit::Graph>();
312- auto graph_and_parameters = lowering::Lower (mod, method.name (), cfg.lower_info );
290+ std::string ConvertGraphToTRTEngine (const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
291+ // Lower the requested method to simplify its graph and extract weight parameters
292+ auto graph_and_parameters = lowering::Lower (mod, method_name, cfg.lower_info );
313293
314- auto g = graph_and_parameters.first ;
315- auto params = graph_and_parameters.second ;
316- auto named_params = conversion::get_named_params (g->inputs (), params);
317- LOG_INFO (" (LoweredGraph)\n " << *g);
294+ auto convert_cfg = std::move (cfg.convert_info );
295+ auto g = graph_and_parameters.first ;
318296
319- std::unordered_map<torch::jit::Value*, ir::Input> inputs;
320- for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
321- inputs.insert ({g->inputs ()[i], cfg.convert_info .inputs [i]});
322- }
323- auto input_ivalues_map = partitioning::generateRandomInputs (inputs);
324- auto graph_and_mapping = ConstructFallbackGraph (new_mod, g->block (), input_ivalues_map, cfg, named_params);
325- new_g = graph_and_mapping.first ;
326- LOG_INFO (" (FallbackGraph)\n " << *new_g);
297+ auto params = graph_and_parameters.second ;
298+ auto static_params = ir::get_static_params (g->inputs (), params);
327299
328- // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
329- // module
330- if (new_g->inputs ()[0 ]->type ()->str ().find (" __torch__" ) == std::string::npos) {
331- LOG_WARNING (" Didn't generate any TensorRT engines, the compiler did nothing\n " );
332- return mod;
333- }
300+ LOG_INFO (*g << " (CompileGraph)\n " );
334301
335- auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
336- auto schema = util::GenerateGraphSchema (new_method->name (), new_g);
337- new_mod.type ()->addMethod (new_method);
338- new_method->setSchema (schema);
339- }
340- }
302+ // Associate the user defined input specs with the graph inputs in convert_cfg, since some inputs might be static
303+ convert_cfg.inputs = ir::associate_specs_with_inputs (g, cfg.inputs , static_params);
341304
342- return new_mod;
305+ auto engine = conversion::ConvertBlockToEngine (g->block (), convert_cfg, static_params);
306+ return engine; // plain return of the local: std::move here would only block copy elision
343307}
344308
345- torch::jit::script::Module CompileGraph (const torch::jit::script::Module& mod, CompileSpec cfg) {
346- // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
347- if (cfg.partition_info .enabled ) {
348- return CompileGraphWithFallback (mod, cfg);
349- }
309+ torch::jit::Module CompileGraph (const torch::jit::Module& mod, CompileSpec cfg) {
310+ torch::jit::Module new_mod (mod._ivalue ()->name () + " _trt" );
311+
350312 auto device_spec = cfg.convert_info .engine_settings .device ;
351313
352314 // GPU default WS size : 1 GB
@@ -362,25 +324,59 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
362324 }
363325 }
364326
365- // TODO: Should be doing a functional transform but need PR #31978
366- // [jit] More robust mangling
367- // torch::jit::script::Module new_mod = mod.clone();
368- torch::jit::script::Module new_mod (mod._ivalue ()->name () + " _trt" );
369- std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
370- for (const torch::jit::script::Method& method : mod.get_methods ()) {
371- // Compile only forward methods. forward method contains the entire graph.
327+ for (const torch::jit::Method& method : mod.get_methods ()) {
372328 if (method.name ().compare (" forward" ) == 0 ) {
373- auto engine = ConvertGraphToTRTEngine (mod, method.name (), cfg);
374329 auto new_g = std::make_shared<torch::jit::Graph>();
375- auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
376- AddEngineToGraph (new_mod, new_g, engine, cuda_device);
330+
331+ auto graph_and_parameters = lowering::Lower (mod, method.name (), cfg.lower_info );
332+
333+ auto g = graph_and_parameters.first ;
334+ LOG_INFO (" Lowered Graph: " << *g);
335+ auto params = graph_and_parameters.second ;
336+ auto static_params = ir::get_static_params (g->inputs (), params);
337+
338+ cfg.convert_info .inputs = std::move (ir::associate_specs_with_inputs (g, cfg.inputs , static_params));
339+
340+ // If the user did not explicitly set the input type, then use the first
341+ // tensor calculation to infer type.
342+ auto first_use_types = util::get_block_first_calc_dtypes_opt (g->block ());
343+ for (auto & in : g->inputs ()) {
344+ auto est_type_opt = first_use_types[in];
345+ ir::Input& spec = cfg.convert_info .inputs .find (in)->second ;
346+ if (est_type_opt && !spec.dtype_is_user_defined ) {
347+ spec.dtype = util::ScalarTypeToTRTDataType (est_type_opt.value ());
348+ } else if (!est_type_opt && !spec.dtype_is_user_defined ) {
349+ LOG_WARNING (
350+ " Cannot deterime input type from calcuations in graph for input "
351+ << in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
352+ spec.dtype = nvinfer1::DataType::kFLOAT ;
353+ }
354+ }
355+
356+ if (cfg.partition_info .enabled ) {
357+ auto input_ivalues_map = partitioning::generateRandomInputs (cfg.convert_info .inputs , first_use_types);
358+ auto graph_and_mapping = ConstructFallbackGraph (new_mod, g->block (), input_ivalues_map, cfg, static_params);
359+ new_g = graph_and_mapping.first ;
360+ LOG_INFO (" Segmented Graph: " << *new_g);
361+
362+ // If the fallback graph has no TensorRT engine `self` input, nothing was converted, so we just return the
363+ // initial module
364+ if (new_g->inputs ()[0 ]->type ()->str ().find (" __torch__" ) == std::string::npos) {
365+ LOG_WARNING (" Didn't generate any TensorRT engines, the compiler did nothing\n " );
366+ return mod;
367+ }
368+ } else {
369+ auto engine = conversion::ConvertBlockToEngine (g->block (), cfg.convert_info , static_params);
370+ auto device_spec = cfg.convert_info .engine_settings .device ;
371+ auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
372+ AddEngineToGraph (new_mod, new_g, engine, cuda_device);
373+ }
377374 auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
378375 auto schema = util::GenerateGraphSchema (new_method->name (), new_g);
379376 new_mod.type ()->addMethod (new_method);
380377 new_method->setSchema (schema);
381378 }
382379 }
383-
384380 return new_mod;
385381}
386382
0 commit comments