@@ -94,11 +94,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
9494 * \param data_in The input data.
9595 */
9696void PipelineExecutor::SetInput (std::string input_name, DLTensor* data_in) {
97- std::pair<int , int > indexs = this ->GetInputIndex (input_name);
98- if (indexs.first < 0 || indexs.first >= static_cast <int >(runtimes_.size ())) {
99- LOG (FATAL) << " input name " << input_name << " not found." ;
100- }
101- runtimes_[indexs.first ]->SetInput (indexs.second , data_in);
97+ global_runtime_->SetPipelineInput (input_name, data_in);
10298}
10399/* !
104100 * \brief get input from the runtime module.
@@ -118,7 +114,7 @@ NDArray PipelineExecutor::GetInput(std::string input_name) {
118114 * \return int The module index.
119115 */
120116int PipelineExecutor::GetParamModuleIndex (const std::string& name) {
121- return param_connection_config [name];
117+ return param_connection_config_ [name];
122118}
123119/* !
124120 * \brief Using the global input name to get the index, and also get the input interface name
@@ -127,7 +123,7 @@ int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
127123 * \return Returning the index and the input interface name of corresponding subgraph.
128124 */
129125Array<String> PipelineExecutor::GetInputPipeplineMap (std::string input_name) {
130- std::pair<int , std::string> map = input_connection_config [input_name];
126+ std::pair<int , std::string> map = input_connection_config_ [input_name];
131127 return {std::to_string (map.first ), map.second };
132128}
133129
@@ -137,11 +133,11 @@ Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
137133 * \return int The module index.
138134 */
139135int PipelineExecutor::GetParamsGroupPipelineMap (const std::string& name) {
140- return param_connection_config [name];
136+ return param_connection_config_ [name];
141137}
142138
143139/* !\brief Run the pipeline executor.*/
144- void PipelineExecutor::Run () { pipeline_scheduler_.PipelineRun (runtimes_, pipeline_config_ ); }
140+ void PipelineExecutor::Run () { pipeline_scheduler_.PipelineRun (runtimes_); }
145141/* !
146142 * \brief return A list of global output data.
147143 */
@@ -226,7 +222,7 @@ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_
226222 * \return std::pair<int, int> A pair of module index and the input index.
227223 */
228224std::pair<int , int > PipelineExecutor::GetInputIndex (const std::string& name) {
229- std::pair<int , std::string> index = input_connection_config [name];
225+ std::pair<int , std::string> index = input_connection_config_ [name];
230226 auto gruntime = runtimes_[index.first ];
231227 return std::make_pair (index.first , gruntime->GetInputIndex (index.second ));
232228}
@@ -250,7 +246,9 @@ void PipelineExecutor::Init(const std::vector<Module>& modules, const std::strin
250246 num_outputs_ = pipeline_config_.GetGlobalOutputNum ();
251247 // Initialize the pipeline function class used for pipeline thread pool management
252248 // and schedule etc. This function returns a list of runtime.
253- runtimes_ = pipeline_scheduler_.PipelineInit (modules, pipeline_config_);
249+ global_runtime_ =
250+ pipeline_scheduler_.PipelineInit (modules, pipeline_config_, input_connection_config_);
251+ runtimes_ = global_runtime_->GetRuntimeList ();
254252 return ;
255253}
256254
0 commit comments