Skip to content

Commit 351f31b

Browse files
authored
[Runtime][PipelineExecutor]Add forwarding queue logic for set input. (#10990)
* [Runtime][PipelineExecutor]Add forwarding queue logic for set input. When the set_input function get called, a runtime of pipeline may not yet finish the former computation work then the new set_input call would break the current computation logic, to avoid such issue, we add the forwarding queue logic to guarantee the order of input data consuming. * polish the documents.
1 parent 8aafe5b commit 351f31b

File tree

7 files changed

+242
-133
lines changed

7 files changed

+242
-133
lines changed

python/tvm/contrib/pipeline_executor.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,7 @@ def set_input(self, key, value):
164164
value : array_like.
165165
The input value
166166
"""
167-
v = self._get_input(key)
168-
if v is None:
169-
raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
170-
v.copyfrom(value)
167+
self._set_input(key, tvm.nd.array(value))
171168

172169
def set_params(self, params_group_name, params_data):
173170
"""Set the parameter group value given the parameter group name. Note that the parameter

src/runtime/pipeline/pipeline_executor.cc

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
9494
* \param data_in The input data.
9595
*/
9696
void PipelineExecutor::SetInput(std::string input_name, DLTensor* data_in) {
97-
std::pair<int, int> indexs = this->GetInputIndex(input_name);
98-
if (indexs.first < 0 || indexs.first >= static_cast<int>(runtimes_.size())) {
99-
LOG(FATAL) << "input name " << input_name << " not found.";
100-
}
101-
runtimes_[indexs.first]->SetInput(indexs.second, data_in);
97+
global_runtime_->SetPipelineInput(input_name, data_in);
10298
}
10399
/*!
104100
* \brief get input from the runtime module.
@@ -118,7 +114,7 @@ NDArray PipelineExecutor::GetInput(std::string input_name) {
118114
* \return int The module index.
119115
*/
120116
int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
121-
return param_connection_config[name];
117+
return param_connection_config_[name];
122118
}
123119
/*!
124120
* \brief Using the global input name to get the index, and also get the input interface name
@@ -127,7 +123,7 @@ int PipelineExecutor::GetParamModuleIndex(const std::string& name) {
127123
* \return Returning the index and the input interface name of corresponding subgraph.
128124
*/
129125
Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
130-
std::pair<int, std::string> map = input_connection_config[input_name];
126+
std::pair<int, std::string> map = input_connection_config_[input_name];
131127
return {std::to_string(map.first), map.second};
132128
}
133129

@@ -137,11 +133,11 @@ Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
137133
* \return int The module index.
138134
*/
139135
int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
140-
return param_connection_config[name];
136+
return param_connection_config_[name];
141137
}
142138

143139
/*!\brief Run the pipeline executor.*/
144-
void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_); }
140+
void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_); }
145141
/*!
146142
* \brief return A list of global output data.
147143
*/
@@ -226,7 +222,7 @@ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_
226222
* \return std::pair<int, int> A pair of module index and the input index.
227223
*/
228224
std::pair<int, int> PipelineExecutor::GetInputIndex(const std::string& name) {
229-
std::pair<int, std::string> index = input_connection_config[name];
225+
std::pair<int, std::string> index = input_connection_config_[name];
230226
auto gruntime = runtimes_[index.first];
231227
return std::make_pair(index.first, gruntime->GetInputIndex(index.second));
232228
}
@@ -250,7 +246,9 @@ void PipelineExecutor::Init(const std::vector<Module>& modules, const std::strin
250246
num_outputs_ = pipeline_config_.GetGlobalOutputNum();
251247
// Initialize the pipeline function class used for pipeline thread pool management
252248
// and schedule etc. This function returns a list of runtime.
253-
runtimes_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_);
249+
global_runtime_ =
250+
pipeline_scheduler_.PipelineInit(modules, pipeline_config_, input_connection_config_);
251+
runtimes_ = global_runtime_->GetRuntimeList();
254252
return;
255253
}
256254

src/runtime/pipeline/pipeline_executor.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,16 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
176176
/*!\brief The dependency information of each graph runtime module of the pipeline.*/
177177
ConfigPipelineExecution pipeline_config_;
178178
/*!\brief The map of global input and subgraph input.*/
179-
InputConnectionConfig input_connection_config;
179+
InputConnectionConfig input_connection_config_;
180180
/*!\brief The map includes global parameters groups and runtime modules.*/
181-
ParamConnectionConfig param_connection_config;
181+
ParamConnectionConfig param_connection_config_;
182182
/*!\brief The module information used to create the graph runtimes.*/
183183
ModuleConfig mod_config_;
184184
/*!\brief How many outputs are in this pipeline executor.*/
185185
size_t num_outputs_ = 0;
186186
/*!The list of backend runtime module.*/
187187
std::vector<std::shared_ptr<BackendRuntime>> runtimes_;
188+
std::shared_ptr<GlobalRuntime> global_runtime_;
188189
/*!\brief Json loader.*/
189190
void LoadConfig(dmlc::JSONReader* reader) {
190191
reader->BeginObject();
@@ -193,9 +194,9 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
193194
if (key == "module_connection") {
194195
reader->Read(&pipeline_config_);
195196
} else if (key == "input_connection") {
196-
reader->Read(&input_connection_config);
197+
reader->Read(&input_connection_config_);
197198
} else if (key == "param_connection") {
198-
reader->Read(&param_connection_config);
199+
reader->Read(&param_connection_config_);
199200
} else {
200201
LOG(FATAL) << "do not support key " << key;
201202
}

src/runtime/pipeline/pipeline_scheduler.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,20 @@ namespace runtime {
2828
* \param modules The list of graph executor modules.
2929
* \param pipeline_conf The dependency information of each graph executor module.
3030
*/
31-
std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
32-
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config) {
31+
std::shared_ptr<GlobalRuntime> PipelineScheduler::PipelineInit(
32+
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config,
33+
const InputConnectionConfig& input_connection_config) {
3334
std::vector<std::shared_ptr<BackendRuntime>> runtimes;
3435
graph_modules_ = modules;
35-
global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
3636
// Creating a list of runtimes.
3737
for (size_t i = 0; i < graph_modules_.size(); i++) {
3838
auto run_item = std::make_shared<BackendRuntime>(graph_modules_[i], i);
3939
runtimes.push_back(run_item);
4040
}
41+
// Creating the global runtime to represent the pipeline executor.
42+
global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
43+
// Initializing the data structures used by pipeline logic.
44+
global_runtime_->InitializePipeline(input_connection_config, runtimes);
4145
// Creating a list of NDArray in order to storage the outputs data.
4246
auto global_output_map = pipeline_config.GetGlobalConfigOutputBindings();
4347
for (size_t i = 0; i < global_output_map.size(); i++) {
@@ -52,15 +56,14 @@ std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
5256
for (auto runtime : runtimes) {
5357
runtime->InitializePipeline(pipeline_config, &runtimes, global_runtime_);
5458
}
55-
return runtimes;
59+
return global_runtime_;
5660
}
5761
/*!
5862
* \brief Running pipeline logic.
5963
* \param runtimes A list of backend runtime modules.
6064
* \param pipeline_config The dependency configuration of each runtime module.
6165
*/
62-
void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
63-
ConfigPipelineExecution pipeline_config) {
66+
void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes) {
6467
runtimes.front()->RunPipeline();
6568
}
6669
/*!

src/runtime/pipeline/pipeline_scheduler.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,14 @@ class PipelineScheduler {
4141
* \param modules The list of graph executor module.
4242
* \param pipeline_config The dependency information of each graph executor module.
4343
*/
44-
std::vector<std::shared_ptr<BackendRuntime>> PipelineInit(
45-
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config);
44+
std::shared_ptr<GlobalRuntime> PipelineInit(const std::vector<Module>& modules,
45+
const ConfigPipelineExecution& pipeline_config,
46+
const InputConnectionConfig& input_connection_config);
4647
/*!
4748
* \brief Running the pipeline logic.
4849
* \param runtimes A list of backend runtime modules.
49-
* \param pipeline_config The dependency configuration of each runtime module.
5050
*/
51-
void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
52-
ConfigPipelineExecution pipeline_config);
51+
void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes);
5352
/*!
5453
* \brief Get a list of outputs.
5554
*/

0 commit comments

Comments
 (0)