@@ -294,13 +294,13 @@ void MapInputsAndDetermineDTypes(
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
     ir::CollectionTypeMap& first_use_type_map) {
-  cfg.convert_info.collection_inputs = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+  cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
 
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG("In MapInputsAndDetermineDTypes, the g->inputs() size is " << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size());
 
   for (auto in : collection_inputs) {
-    std::vector<ir::Input>& spec = cfg.convert_info.collection_inputs.find(in)->second;
+    std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
     std::vector<c10::optional<at::ScalarType>> est_type_opt;
 
     auto est_it = first_use_type_map.find(in);
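For context: after this rename, `collection_input_spec_map` is read as a map from each collection input to its per-element `ir::Input` specs, via `.find(in)->second`. Below is a minimal standalone sketch of that access pattern; `Value` and `Input` here are hypothetical stand-ins for the real `torch::jit::Value*` and `ir::Input` types, for illustration only.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for torch::jit::Value* and ir::Input (illustration only).
struct Value {
  std::string debug_name;
};
struct Input {
  std::string dtype;
};

// Assumed shape of the renamed field: one spec vector per collection input.
using CollectionInputSpecMap = std::unordered_map<const Value*, std::vector<Input>>;

int main() {
  Value in{"input_0"};
  CollectionInputSpecMap collection_input_spec_map;
  collection_input_spec_map[&in] = {Input{"float32"}, Input{"int32"}};

  // Mirrors cfg.convert_info.collection_input_spec_map.find(in)->second in the diff.
  std::vector<Input>& spec = collection_input_spec_map.find(&in)->second;
  std::cout << in.debug_name << " has " << spec.size() << " element specs\n";
  return 0;
}
```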
@@ -327,21 +327,21 @@ void MapInputsAndDetermineDTypes(
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
           std::stringstream ss;
           ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-          ss << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
-          ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
           // Overwrite type map with user settings
-          first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype)};
+          first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
 
         } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype) != est_type_opt[i].value()) {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != est_type_opt[i].value()) {
             std::stringstream ss;
             ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
             ss << ", however when inspecting the graph, the input type expected was inferred to be ";
             ss << est_type_opt[i].value() << std::endl;
-            ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+            ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
             ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
             ss << "compatibility with PyTorch's data type convention is required.\n";
             ss << "If you do indeed see errors at runtime either:\n";
@@ -350,7 +350,7 @@ void MapInputsAndDetermineDTypes(
             auto warn_str = ss.str();
             LOG_WARNING(warn_str);
             // Overwrite type map with user settings
-            first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype)};
+            first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
           }
         }
       } else {
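The three hunks above implement one policy per collection-input element: when partial compilation is enabled and the user pinned a dtype, the user setting wins, both when no type can be inferred from the graph and when the inferred type conflicts, and `first_use_type_map` is overwritten accordingly. Here is a self-contained sketch of that resolution rule; `resolve_dtype` is a hypothetical helper and uses strings where the real code compares `at::ScalarType` against the TensorRT dtype via `util::TRTDataTypeToScalarType`.

```cpp
#include <iostream>
#include <optional>
#include <string>

// Hypothetical helper condensing the per-element rule from the diff:
// the user-specified dtype always overrides the type inferred from the graph.
std::string resolve_dtype(const std::optional<std::string>& inferred, const std::string& user_setting) {
  if (!inferred) {
    // Mirrors the "Cannot infer input tensor dtype in graph" branch.
    std::cout << "cannot infer dtype, using user setting " << user_setting << "\n";
  } else if (*inferred != user_setting) {
    // Mirrors the conflict warning; under partial compilation this may error at runtime.
    std::cout << "inferred " << *inferred << " conflicts with user setting " << user_setting
              << ", using user setting\n";
  }
  return user_setting;  // corresponds to overwriting first_use_type_map[in][i]
}

int main() {
  resolve_dtype(std::nullopt, "float16");            // no inferred type: warn, keep user setting
  resolve_dtype(std::string("float32"), "float16");  // conflict: warn, keep user setting
  resolve_dtype(std::string("float16"), "float16");  // agreement: silent
  return 0;
}
```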
@@ -436,7 +436,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
         cfg.partition_info.forced_fallback_operators.size() == 0 &&
         conversion::VerifyConverterSupportForBlock(g->block(), false))) {
-    auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_inputs, first_use_types);
+    auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
     auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params);
     new_g = graph_and_mapping.first;
     LOG_INFO("Segmented Graph: " << *new_g);
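At this last call site the renamed map feeds `partitioning::generateRandomInputs`, which fabricates example inputs from the specs so the fallback graph can be constructed. Below is a rough standalone sketch of that idea with stand-in types: a string-keyed map instead of `torch::jit::Value*` keys, and placeholder strings where the real code produces random `at::Tensor`s wrapped in IValues, assuming one generated input per spec entry.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins echoing the call site (illustration only).
struct Input {
  std::vector<int64_t> shape;
};
using SpecMap = std::unordered_map<std::string, std::vector<Input>>;

// Produce one placeholder description per spec entry; the real
// generateRandomInputs builds random tensors of the requested shapes.
std::unordered_map<std::string, std::vector<std::string>> generate_random_inputs(const SpecMap& specs) {
  std::unordered_map<std::string, std::vector<std::string>> out;
  for (const auto& [name, spec] : specs) {
    for (const auto& input : spec) {
      std::string desc = "randn(";
      for (size_t d = 0; d < input.shape.size(); d++) {
        desc += (d ? "x" : "") + std::to_string(input.shape[d]);
      }
      out[name].push_back(desc + ")");
    }
  }
  return out;
}

int main() {
  SpecMap specs{{"input_0", {Input{{1, 3, 224, 224}}, Input{{1, 10}}}}};
  auto inputs = generate_random_inputs(specs);
  for (const auto& desc : inputs["input_0"]) {
    std::cout << desc << "\n";  // e.g. randn(1x3x224x224)
  }
  return 0;
}
```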