@@ -118,9 +118,11 @@ TRTEngine::TRTEngine(
118118 TORCHTRT_CHECK (
119119 (cuda_engine->getTensorIOMode (binding_name.c_str ()) == nvinfer1::TensorIOMode::kINPUT ),
120120 " Binding " << binding_name << " specified as input but found as output in TensorRT engine" );
121- LOG_DEBUG (" Input binding name: " << binding_name << " pyt arg idx: " << pyt_idx << " )" );
121+ LOG_DEBUG (
122+ " Input binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
123+ << " , Torch binding index: " << pyt_idx);
122124 in_binding_map[trt_idx] = pyt_idx;
123- in_binding_names[pyt_idx] = _in_binding_names[pyt_idx] ;
125+ in_binding_names[pyt_idx] = binding_name ;
124126 }
125127
126128 uint64_t outputs = _out_binding_names.size ();
@@ -210,19 +212,21 @@ std::string TRTEngine::to_str() const {
210212 ss << " Inputs: [" << std::endl;
211213 for (uint64_t i = 0 ; i < num_io.first ; i++) {
212214 ss << " id: " << i << std::endl;
213- ss << " shape: " << exec_ctx->getTensorShape (std::string (" input_" + str (i)).c_str ()) << std::endl;
215+ ss << " name: " << in_binding_names[i].c_str () << std::endl;
216+ ss << " shape: " << exec_ctx->getTensorShape (in_binding_names[i].c_str ()) << std::endl;
214217 ss << " dtype: "
215- << util::TRTDataTypeToScalarType (exec_ctx->getEngine ().getTensorDataType (std::string ( " input_ " + str (i)) .c_str ()))
218+ << util::TRTDataTypeToScalarType (exec_ctx->getEngine ().getTensorDataType (in_binding_names[i] .c_str ()))
216219 << std::endl;
217220 }
218221 ss << " ]" << std::endl;
219222 ss << " Outputs: [" << std::endl;
220223 for (uint64_t o = 0 ; o < num_io.second ; o++) {
221224 ss << " id: " << o << std::endl;
222- ss << " shape: " << exec_ctx->getTensorShape (std::string (" output_" + str (o)).c_str ()) << std::endl;
225+ ss << " name: " << out_binding_names[o].c_str () << std::endl;
226+ ss << " shape: " << exec_ctx->getTensorShape (out_binding_names[o].c_str ()) << std::endl;
223227 ss << " dtype: "
224228 << util::TRTDataTypeToScalarType (
225- exec_ctx->getEngine ().getTensorDataType (std::string ( " output_ " + str (o)) .c_str ()))
229+ exec_ctx->getEngine ().getTensorDataType (out_binding_names[o] .c_str ()))
226230 << std::endl;
227231 }
228232 ss << " }" << std::endl;
0 commit comments