Skip to content

Commit b1bd18e

Browse files
huajsjcomaniac
andauthored
[Runtime][Pipeline executor] Global parameters group name and runtime modules parameters map. (#9846)
* [Runtime][Pipeline executor] Global parameters group name and runtime modules parameters map. Solution: To support on the fly parameters setting for each runtime module in pipeline executor, we create a feature that use global parameters group name to map the runtime module parameter, after such map relation get created user can do the on the fly parameters setting by using the parameters group name. trigger build. fix ut issue. polish comments. Update python/tvm/contrib/pipeline_executor.py Co-authored-by: Cody Yu <[email protected]> Update python/tvm/contrib/pipeline_executor.py Co-authored-by: Cody Yu <[email protected]> Update python/tvm/contrib/pipeline_executor.py Co-authored-by: Cody Yu <[email protected]> Update python/tvm/contrib/pipeline_executor.py Co-authored-by: Cody Yu <[email protected]> Update src/runtime/pipeline/pipeline_executor.h Co-authored-by: Cody Yu <[email protected]> Update src/runtime/pipeline/pipeline_struct.h Co-authored-by: Cody Yu <[email protected]> Update python/tvm/contrib/pipeline_executor.py Co-authored-by: Cody Yu <[email protected]> address review comments. * Update python/tvm/contrib/pipeline_executor.py Co-authored-by: Cody Yu <[email protected]> * fix plint issue. Co-authored-by: Cody Yu <[email protected]>
1 parent 4f29562 commit b1bd18e

File tree

5 files changed

+261
-49
lines changed

5 files changed

+261
-49
lines changed

python/tvm/contrib/pipeline_executor.py

Lines changed: 94 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def build(pipe_configs):
5454
raise RuntimeError('"module_connection" is missing')
5555
if "input_connection" not in config:
5656
raise RuntimeError('"input_connection" is missing')
57+
if "param_connection" not in config:
58+
raise RuntimeError('"param_connection" is missing')
5759

5860
mod_n_configs = config["module_connection"]
5961
config_len = len(mod_n_configs)
@@ -91,6 +93,7 @@ def build(pipe_configs):
9193
# map of global input and subgraph input, and the "module_connection" is used to
9294
# record module dependency.
9395
string_config = {}
96+
string_config["param_connection"] = config["param_connection"]
9497
string_config["input_connection"] = config["input_connection"]
9598
string_config["module_connection"] = module_string_config
9699

@@ -114,6 +117,8 @@ def __init__(self, module):
114117
# Get the packed functions from the pipeline executor.
115118
self._get_num_outputs = self.module["get_num_outputs"]
116119
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
120+
self._get_params_group_pipeline_map = self.module["get_params_group_pipeline_map"]
121+
self._set_param = self.module["set_param"]
117122

118123
def get_input_pipeline_map(self, name):
119124
"""Using the "name" to get the corresponding subgraph index and also get the "input name"
@@ -125,6 +130,39 @@ def get_input_pipeline_map(self, name):
125130
"""
126131
return self._get_input_pipeline_map(name)
127132

133+
def get_params_group_pipeline_map(self, name):
134+
"""Use the name of the parameters group to get the corresponding runtime module index.
135+
136+
Parameters
137+
----------
138+
name: str
139+
The parameter group name.
140+
141+
Returns
142+
-------
143+
module_index: int
144+
The index of the runtime module.
145+
"""
146+
return self._get_params_group_pipeline_map(name)
147+
148+
def set_params(self, params_group_name, params_data):
149+
"""Set the parameter group value given the parameter group name. Note that the parameter
150+
group name is declared in the pipeline executor config.
151+
152+
Parameters
153+
----------
154+
params_group_name : str
155+
The parameters group name.
156+
157+
params_data : Dict[str, NDArray]
158+
A map from parameter name to data.
159+
"""
160+
if not params_data:
161+
raise RuntimeError('"params_data is empty!"')
162+
163+
for key, val in params_data.items():
164+
self._set_param(params_group_name, key, val)
165+
128166
@property
129167
def num_outputs(self):
130168
"""Get the number of outputs.
@@ -311,9 +349,19 @@ def connect(self, binding):
311349
if self.io_owner == binding.io_owner:
312350
raise RuntimeError("Can not bind itself.")
313351

352+
if self.io_type == "param" and not self.is_pipeline_executor_interface():
353+
raise RuntimeError(
354+
'The "param" binding can only be used by a pipeline executor interface!'
355+
)
356+
314357
if not self.is_pipeline_executor_interface() and self.io_type == "input":
315358
raise RuntimeError("Module can only bind from output interface!")
316359

360+
if self.io_type == "param" and binding.io_type != "param":
361+
raise RuntimeError(
362+
'A global "param" interface can only be bind with a module "param" interface!'
363+
)
364+
317365
if (
318366
not self.is_pipeline_executor_interface()
319367
and not binding.is_pipeline_executor_interface()
@@ -412,6 +460,7 @@ def __init__(self, mod=None):
412460
self.output_type = InferType()(mod)["main"].checked_type.ret_type
413461
self.input_bindings = PipelineConfig.BindingList(self, "input")
414462
self.output_bindings = PipelineConfig.BindingList(self, "output")
463+
self.param_binding = PipelineConfig.Binding(self, "param", "param")
415464

416465
def __eq__(self, other):
417466
if isinstance(other, PipelineConfig.ModuleWrapper):
@@ -427,6 +476,9 @@ def __getitem__(self, key):
427476
if key == "output":
428477
return self.output_bindings
429478

479+
if key == "param":
480+
return self.param_binding
481+
430482
raise RuntimeError(f"{key} not found!")
431483

432484
raise RuntimeError('The data type of "key" is not supported!')
@@ -483,14 +535,21 @@ def __init__(self):
483535
self.mod_wrapper = {}
484536
self.input_bindings = self.BindingList(self, "input")
485537
self.output_bindings = self.BindingList(self, "output")
538+
# There is a map of global parameters group and module index.
539+
self.param_group_bindings = self.BindingList(self, "param")
486540

487541
def __str__(self):
488542
# Get configuration information as a string.
489543

490544
# Use topological sort to get correct module order.
491545
self.dag_topology_sort()
546+
# Getting the parameters dependencies.
547+
param_dump = "Params\n"
548+
for param_name in self.param_group_bindings.bindings:
549+
inf = self.param_group_bindings.bindings[param_name]
550+
param_dump += str(inf) + "\n"
492551
# Get the input dependencies.
493-
input_dump = "Inputs\n"
552+
input_dump = "\nInputs\n"
494553
for input_name in self.input_bindings.bindings:
495554
inf = self.input_bindings.bindings[input_name]
496555
input_dump += str(inf) + "\n"
@@ -516,7 +575,7 @@ def __str__(self):
516575
for name in sorted(output.keys()):
517576
output_dump += f" |output({name}) : {output[name]}\n"
518577

519-
return input_dump + output_dump + connections_dump
578+
return param_dump + input_dump + output_dump + connections_dump
520579

521580
def __getitem__(self, key):
522581
if isinstance(key, tvm.ir.module.IRModule):
@@ -529,8 +588,12 @@ def __getitem__(self, key):
529588
return self.input_bindings
530589
if key == "output":
531590
return self.output_bindings
591+
if key == "param_group":
592+
return self.param_group_bindings
593+
594+
raise RuntimeError(f"{key} not found!")
532595

533-
raise RuntimeError(f"{key} not found.")
596+
raise RuntimeError(f'The key type "{type(key)}" is not supported!')
534597

535598
def get_config(self):
536599
"""Get the configuration information in dictionary form, this configuration
@@ -541,7 +604,6 @@ def get_config(self):
541604
self.dag_topology_sort()
542605
mconfig = {}
543606
module_connection = {}
544-
input_connection = {}
545607
for mod in self.mod_wrapper:
546608
# Generate pipeline configuration.
547609
mconf = {}
@@ -579,22 +641,33 @@ def get_config(self):
579641
"dev": module.dev,
580642
}
581643

582-
# Create a map of pipeline input and subgraph input.
583-
input_connection = []
584-
for input_name in self.input_bindings.bindings:
585-
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
586-
if "interface_name" not in input_dict["connection"][0]:
587-
raise RuntimeError("interface_name is missing in connection config!")
588-
# Creating the map of global interface and subgraph interface.
589-
input_map = {
590-
"global_interface_name": input_dict["interface_name"],
591-
"mod_idx": input_dict["connection"][0]["mod_idx"],
592-
"module_interface_name": input_dict["connection"][0]["interface_name"],
593-
}
594-
input_connection.append(input_map)
644+
# Creating a map including pipeline inputs and subgraph inputs.
645+
input_connection = []
646+
for input_name in self.input_bindings.bindings:
647+
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
648+
if "interface_name" not in input_dict["connection"][0]:
649+
raise RuntimeError("interface_name is missing in connection config!")
650+
# Creating the map including global interfaces and subgraph interfaces.
651+
input_map = {
652+
"global_interface_name": input_dict["interface_name"],
653+
"mod_idx": input_dict["connection"][0]["mod_idx"],
654+
"module_interface_name": input_dict["connection"][0]["interface_name"],
655+
}
656+
input_connection.append(input_map)
657+
658+
# Create a map including global parameters groups and modules.
659+
param_connection = []
660+
for param_name in self.param_group_bindings.bindings:
661+
param_dict = self.param_group_bindings.bindings[param_name].get_binding_dict()
662+
param_map = {
663+
"global_param_name": param_dict["interface_name"],
664+
"mod_idx": param_dict["connection"][0]["mod_idx"],
665+
}
666+
param_connection.append(param_map)
595667

596668
mconfig["module_connection"] = module_connection
597669
mconfig["input_connection"] = input_connection
670+
mconfig["param_connection"] = param_connection
598671
return mconfig
599672

600673
def dag_topology_sort(self):
@@ -613,8 +686,12 @@ def dag_topology_sort(self):
613686

614687
mlist += temp_list
615688

689+
mod_wrapper_sort = {}
616690
for mod, i in zip(mlist, range(len(mlist))):
617691
self.mod_wrapper[mod].set_idx_name(i)
692+
mod_wrapper_sort[mod] = self.mod_wrapper[mod]
693+
694+
self.mod_wrapper = mod_wrapper_sort
618695

619696
def get_mod_idx(self, mod):
620697
# Return the module index.

src/runtime/pipeline/pipeline_executor.cc

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,27 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
3737
} else if (name == "get_input_pipeline_map") {
3838
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
3939
if (String::CanConvertFrom(args[0])) {
40-
*rv = this->GetInputPipeplineMapping(args[0].operator String());
40+
*rv = this->GetInputPipeplineMap(args[0].operator String());
4141
} else {
4242
LOG(FATAL) << "Function only support the input name value in the form of string";
4343
}
4444
});
45+
} else if (name == "get_params_group_pipeline_map") {
46+
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
47+
if (String::CanConvertFrom(args[0])) {
48+
*rv = this->GetParamsGroupPipelineMap(args[0].operator String());
49+
} else {
50+
LOG(FATAL) << "Function only support the input name value in the form of string";
51+
}
52+
});
53+
} else if (name == "set_param") {
54+
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
55+
if (String::CanConvertFrom(args[0]) && String::CanConvertFrom(args[1])) {
56+
this->SetParam(args[0].operator String(), args[1].operator String(), args[2]);
57+
} else {
58+
LOG(FATAL) << "Function only support the parameter name and the key in the form of string";
59+
}
60+
});
4561
} else {
4662
LOG(FATAL) << "Unknown packed function: " << name;
4763
return PackedFunc();
@@ -55,11 +71,20 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
5571
* \param The global input name.
5672
* \return Returning the index and the input interface name of corresponding subgraph.
5773
*/
58-
Array<String> PipelineExecutor::GetInputPipeplineMapping(std::string input_name) {
74+
Array<String> PipelineExecutor::GetInputPipeplineMap(std::string input_name) {
5975
std::pair<int, std::string> map = input_connection_config[input_name];
6076
return {std::to_string(map.first), map.second};
6177
}
6278

79+
/*!
80+
* \brief Return the module index for the parameters group name.
81+
* \param name The parameters group name.
82+
* \return int The module index.
83+
*/
84+
int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
85+
return param_connection_config[name];
86+
}
87+
6388
/*!
6489
* \brief Use the mod_config information to create a graph runtime list.
6590
* \param mod_config The config information that generates by the export library function call.
@@ -115,7 +140,18 @@ std::vector<Module> PipelineExecutor::CreateGraphModules(const ModuleConfig& mod
115140
}
116141
return ret;
117142
}
118-
143+
/*!
144+
* \brief Set a parameter into a graph module.
145+
* \param param_group_name The parameters group name.
146+
* \param param_key_name The parameter key name.
147+
* \param data_in The parameter data.
148+
*/
149+
void PipelineExecutor::SetParam(std::string param_group_name, std::string param_key_name,
150+
DLTensor* data_in) {
151+
// Get the module index from the param name.
152+
int module_index = this->GetParamsGroupPipelineMap(param_group_name);
153+
// TODO(huajsj): set the parameters into runtime module.
154+
}
119155
/*!
120156
* \brief Initialize the pipeline executor with a list of modules to be pipelined
121157
* and config in JSON format.

src/runtime/pipeline/pipeline_executor.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,27 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
7575
* \param The global input name.
7676
* \return Returning the index and the input interface name of corresponding subgraph.
7777
*/
78-
Array<String> GetInputPipeplineMapping(std::string input_name);
78+
Array<String> GetInputPipeplineMap(std::string input_name);
79+
/*!
80+
* \brief This function return a module index for the global parameters group name.
81+
* \param name The parameters group name.
82+
* \return Returning a runtime module index.
83+
*/
84+
int GetParamsGroupPipelineMap(const std::string& name);
85+
/*!
86+
* \brief Use the parameters group name to get the specific backend runtime then use
87+
* the param_key_name to set param data for the said backend runtime.
88+
* \param param_group_name The parameters group name.
89+
* \param param_key_name The parameter key name.
90+
* \param data_in The parameter value.
91+
*/
92+
void SetParam(std::string param_group_name, std::string param_key_name, DLTensor* data_in);
7993
/*!
8094
* \brief Get the number of outputs.
8195
*
8296
* \return The number of outputs.
8397
*/
8498
int NumOutputs() const { return num_outputs_; }
85-
8699
/*!\brief Load the module files information.*/
87100
ModuleConfig& LoadModuleConfig(dmlc::JSONReader* reader) {
88101
reader->BeginArray();
@@ -126,6 +139,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
126139
ConfigPipelineExecution pipeline_config_;
127140
/*!\brief The map of global input and subgraph input.*/
128141
InputConnectionConfig input_connection_config;
142+
/*!\brief The map includes global parameters groups and runtime modules.*/
143+
ParamConnectionConfig param_connection_config;
129144
/*!\brief The module information used to create the graph runtimes.*/
130145
ModuleConfig mod_config_;
131146
/*!\brief How many outputs are in this pipeline executor.*/
@@ -139,6 +154,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
139154
reader->Read(&pipeline_config_);
140155
} else if (key == "input_connection") {
141156
reader->Read(&input_connection_config);
157+
} else if (key == "param_connection") {
158+
reader->Read(&param_connection_config);
142159
} else {
143160
LOG(FATAL) << "do not support key " << key;
144161
}

src/runtime/pipeline/pipeline_struct.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,48 @@ struct InputConnectionConfig {
251251
}
252252
};
253253

254+
/*!
255+
* \brief A map includes global module parameters groups and graph modudles.
256+
*/
257+
struct ParamConnectionConfig {
258+
/*!\brief Mapping from the name of a global module parameters group to the index of a runtime
259+
* module.
260+
*/
261+
std::unordered_map<std::string, int> param_connection;
262+
bool Empty() { return param_connection.empty(); }
263+
int operator[](const std::string key) {
264+
if (param_connection.find(key) == param_connection.end()) {
265+
LOG(FATAL) << "do not support key " << key;
266+
}
267+
return param_connection[key];
268+
}
269+
/*!
270+
* \brief Load from JSONReader.
271+
* \param reader Json reader.
272+
*/
273+
void Load(dmlc::JSONReader* reader) {
274+
reader->BeginArray();
275+
while (reader->NextArrayItem()) {
276+
reader->BeginObject();
277+
std::string key;
278+
std::string global_param_name;
279+
int mod_idx = -1;
280+
while (reader->NextObjectItem(&key)) {
281+
if (key == "global_param_name") {
282+
reader->Read(&global_param_name);
283+
} else if (key == "mod_idx") {
284+
reader->Read(&mod_idx);
285+
} else {
286+
LOG(FATAL) << "do not support key " << key;
287+
}
288+
}
289+
ICHECK(mod_idx >= 0) << "Invalid module index value " << mod_idx;
290+
ICHECK(!global_param_name.empty()) << "Invalid global parameter group name value";
291+
param_connection[global_param_name] = mod_idx;
292+
}
293+
}
294+
};
295+
254296
/*!
255297
* \brief The information used to initialize the graph executor module, the information
256298
* come from the export library function call.

0 commit comments

Comments
 (0)