Commit f397faf

committed by inocsin
chore: [collection] rename ConversionInfo.collection_inputs to ConversionInfo.collection_input_spec_map
Signed-off-by: inocsin <[email protected]>
1 parent 35d5aeb commit f397faf

4 files changed (+13, -19 lines)
core/compiler.cpp

Lines changed: 10 additions & 10 deletions
@@ -294,13 +294,13 @@ void MapInputsAndDetermineDTypes(
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
     ir::CollectionTypeMap& first_use_type_map) {
-  cfg.convert_info.collection_inputs = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+  cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));

   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG("In MapInputsAndDetermineDTypes, the g->inputs() size is " << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size());

   for (auto in : collection_inputs) {
-    std::vector<ir::Input>& spec = cfg.convert_info.collection_inputs.find(in)->second;
+    std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
     std::vector<c10::optional<at::ScalarType>> est_type_opt;

     auto est_it = first_use_type_map.find(in);
@@ -327,21 +327,21 @@ void MapInputsAndDetermineDTypes(
       LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
       std::stringstream ss;
       ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-      ss << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
-      ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+      ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+      ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
       auto warn_str = ss.str();
       LOG_WARNING(warn_str);
       // Overwrite type map with user settings
-      first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype)};
+      first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};

     } else {
-      if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype) != est_type_opt[i].value()) {
+      if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != est_type_opt[i].value()) {
         std::stringstream ss;
         ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-        ss << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+        ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
         ss << ", however when inspecting the graph, the input type expected was inferred to be ";
         ss << est_type_opt[i].value() << std::endl;
-        ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_inputs.find(in)->second[i].dtype;
+        ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
         ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
         ss << "compatibility with PyTorch's data type convention is required.\n";
         ss << "If you do indeed see errors at runtime either:\n";
@@ -350,7 +350,7 @@ void MapInputsAndDetermineDTypes(
         auto warn_str = ss.str();
         LOG_WARNING(warn_str);
         // Overwrite type map with user settings
-        first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_inputs.find(in)->second[i].dtype)};
+        first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
       }
     }
   } else {
@@ -436,7 +436,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partition_info.forced_fallback_operators.size() == 0 &&
        conversion::VerifyConverterSupportForBlock(g->block(), false))) {
-    auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_inputs, first_use_types);
+    auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
     auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params);
     new_g = graph_and_mapping.first;
     LOG_INFO("Segmented Graph: " << *new_g);

core/compiler.h

Lines changed: 1 addition & 5 deletions
@@ -14,12 +14,8 @@ namespace torch_tensorrt {
 namespace core {

 struct CompileSpec {
-  CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {
-    // graph_inputs = ir::GraphInputs(inputs);
-  }
+  CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
   CompileSpec(torch::jit::IValue& input_signature) {
-    // graph_inputs = ir::GraphInputs(input_signature);
-    // inputs = graph_inputs.flattened_inputs;
     graph_inputs.input_signature = input_signature;
   }
   ir::GraphInputs graph_inputs;

core/conversion/conversion.cpp

Lines changed: 1 addition & 3 deletions
@@ -134,9 +134,8 @@ void AddInputs(
     ConversionCtx* ctx,
     c10::ArrayRef<const torch::jit::Value*> inputs,
     ConversionInfo& conversion_info) {
-  // std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs) {
   std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs = conversion_info.inputs;
-  std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>> collection_input_spec = conversion_info.collection_inputs;
+  std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>> collection_input_spec = conversion_info.collection_input_spec_map;

   std::vector<const torch::jit::Value*> input_tensors;
   for (auto in : inputs) {
@@ -396,7 +395,6 @@ void ConvertBlockToNetDef(

   auto inputs = b->inputs();
   AddParamsToCtxValueMap(ctx, static_params);
-  // AddInputs(ctx, inputs, build_info.inputs);
   AddInputs(ctx, inputs, build_info);

   auto nodes = b->nodes();

core/conversion/conversion.h

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ namespace conversion {

 struct ConversionInfo {
   ir::InputSpecMap inputs;
-  ir::CollectionInputSpecMap collection_inputs;
+  ir::CollectionInputSpecMap collection_input_spec_map;
   BuilderSettings engine_settings;
 };
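For context, this renamed field is the only change to ConversionInfo. Judging from the AddInputs hunk above, ir::CollectionInputSpecMap maps a graph input Value* to the vector of ir::Input specs describing that (possibly nested tuple/list) input. Below is a self-contained sketch of that shape using plain standard-library types as stand-ins for the repo's aliases; InputSpec and GraphValue are hypothetical placeholders, not torch_tensorrt types.

// Standalone sketch of the data shape behind ConversionInfo::collection_input_spec_map.
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

struct InputSpec {   // stand-in for torch_tensorrt::core::ir::Input
  std::vector<int64_t> shape;
  std::string dtype;
};

struct GraphValue {  // stand-in for const torch::jit::Value*
  std::string debug_name;
};

// One collection input maps to one spec per flattened element.
using CollectionInputSpecMap = std::unordered_map<const GraphValue*, std::vector<InputSpec>>;

int main() {
  GraphValue tuple_input{"input_0"};

  CollectionInputSpecMap collection_input_spec_map;  // renamed from collection_inputs
  collection_input_spec_map[&tuple_input] = {
      {{1, 3, 224, 224}, "float32"},
      {{1, 10}, "int32"},
  };

  for (const auto& entry : collection_input_spec_map) {
    std::printf("%s carries %zu spec(s)\n", entry.first->debug_name.c_str(), entry.second.size());
  }
  return 0;
}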
