@@ -253,6 +253,7 @@ GraphAndMapping ConstructFallbackGraph(
253
253
}
254
254
// update the input ranges for each segment
255
255
convert_cfg.inputs = ir::associate_specs_with_inputs (seg_block.g (), inputs, static_params);
256
+
256
257
auto engine = conversion::ConvertBlockToEngine (seg_block.block (), convert_cfg, static_params);
257
258
auto temp_g = std::make_shared<torch::jit::Graph>();
258
259
auto device_spec = convert_cfg.engine_settings .device ;
@@ -288,7 +289,7 @@ GraphAndMapping ConstructFallbackGraph(
288
289
}
289
290
290
291
291
- void MapInputsAndDetermineDTypes (CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap & first_use_type_map) {
292
+ void MapInputsAndDetermineDTypes (CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, ir::TypeMap & first_use_type_map) {
292
293
// Associate input specs with inputs
293
294
cfg.convert_info .inputs = std::move (ir::associate_specs_with_inputs (g, cfg.inputs , static_params));
294
295
@@ -303,9 +304,31 @@ void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::G
303
304
} else if (!est_type_opt && !spec.dtype_is_user_defined ) {
304
305
// If we cannot calculate the type and the user did not define the type, then default to FP32
305
306
LOG_WARNING (
306
- " Cannot deterime input type from calcuations in graph for input "
307
+ " Cannot infer input type from calcuations in graph for input "
307
308
<< in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
308
309
spec.dtype = nvinfer1::DataType::kFLOAT ;
310
+ } else if (spec.dtype_is_user_defined && cfg.partition_info .enabled ) {
311
+ if (!est_type_opt) {
312
+ LOG_INFO (" Cannot infer input tensor dtype in graph, unable to verify user input dtype settings" );
313
+ } else {
314
+ if (util::TRTDataTypeToScalarType (cfg.convert_info .inputs .find (in)->second .dtype ) != est_type_opt.value ()) {
315
+ std::stringstream ss;
316
+ ss <<" For input " << in->debugName () << " , found user specified input dtype as " ;
317
+ ss << cfg.convert_info .inputs .find (in)->second .dtype ;
318
+ ss << " , however when inspecting the graph, the input type expected was inferred to be " ;
319
+ ss << est_type_opt.value () << std::endl;
320
+ ss << " The compiler is going to use the user setting " << cfg.convert_info .inputs .find (in)->second .dtype ;
321
+ ss << " \n This conflict may cause an error at runtime due to partial compilation being enabled and therefore\n " ;
322
+ ss << " compatibility with PyTorch's data type convention is required.\n " ;
323
+ ss << " If you do indeed see errors at runtime either:\n " ;
324
+ ss << " - Remove the dtype spec for " << in->debugName () << std::endl;
325
+ ss << " - Disable partial compilation by setting require_full_compilation to True" ;
326
+ auto warn_str = ss.str ();
327
+ LOG_WARNING (warn_str);
328
+ // Overwrite type map with user settings
329
+ first_use_type_map[in] = {util::TRTDataTypeToScalarType (cfg.convert_info .inputs .find (in)->second .dtype )};
330
+ }
331
+ }
309
332
} else {
310
333
// The user defined the type so no changes are necessary
311
334
}
@@ -317,10 +340,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
317
340
auto graph_and_parameters = lowering::Lower (mod, method_name, cfg.lower_info );
318
341
319
342
auto g = graph_and_parameters.first ;
343
+ TRTORCH_CHECK (conversion::VerifyConverterSupportForBlock (g->block ()), " Not all operations in graph are supported by the compiler" );
320
344
auto params = graph_and_parameters.second ;
321
345
auto static_params = ir::get_static_params (g->inputs (), params);
322
346
// Infer the type of an input from the weights of the calculation
323
- auto first_use_types = util ::get_block_first_calc_dtypes_opt (g->block ());
347
+ auto first_use_types = ir ::get_block_first_calc_dtypes_opt (g->block ());
324
348
325
349
MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types);
326
350
@@ -357,11 +381,21 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
357
381
auto params = graph_and_parameters.second ;
358
382
auto static_params = ir::get_static_params (g->inputs (), params);
359
383
// Infer the type of an input from the weights of the calculation
360
- auto first_use_types = util ::get_block_first_calc_dtypes_opt (g->block ());
384
+ auto first_use_types = ir ::get_block_first_calc_dtypes_opt (g->block ());
361
385
362
386
MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types);
363
387
364
- if (cfg.partition_info .enabled ) {
388
+ if (cfg.partition_info .enabled
389
+ && (cfg.lower_info .forced_fallback_modules .size () == 0
390
+ && cfg.partition_info .forced_fallback_operators .size () == 0
391
+ && conversion::VerifyConverterSupportForBlock (g->block (), true ))) {
392
+ LOG_INFO (" Skipping partitioning since model is fully supported" );
393
+ }
394
+
395
+ if (cfg.partition_info .enabled
396
+ && !(cfg.lower_info .forced_fallback_modules .size () == 0
397
+ && cfg.partition_info .forced_fallback_operators .size () == 0
398
+ && conversion::VerifyConverterSupportForBlock (g->block (), false ))) {
365
399
auto input_ivalues_map = partitioning::generateRandomInputs (cfg.convert_info .inputs , first_use_types);
366
400
auto graph_and_mapping = ConstructFallbackGraph (new_mod, g->block (), input_ivalues_map, cfg, static_params);
367
401
new_g = graph_and_mapping.first ;
@@ -374,6 +408,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
374
408
return mod;
375
409
}
376
410
} else {
411
+ TRTORCH_CHECK (conversion::VerifyConverterSupportForBlock (g->block ()), " Not all operations in graph are supported by the compiler" );
377
412
auto engine = conversion::ConvertBlockToEngine (g->block (), cfg.convert_info , static_params);
378
413
auto device_spec = cfg.convert_info .engine_settings .device ;
379
414
auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
0 commit comments