@@ -54,6 +54,8 @@ def build(pipe_configs):
5454 raise RuntimeError ('"module_connection" is missing' )
5555 if "input_connection" not in config :
5656 raise RuntimeError ('"input_connection" is missing' )
57+ if "param_connection" not in config :
58+ raise RuntimeError ('"param_connection" is missing' )
5759
5860 mod_n_configs = config ["module_connection" ]
5961 config_len = len (mod_n_configs )
@@ -91,6 +93,7 @@ def build(pipe_configs):
9193 # map of global input and subgraph input, and the "module_connection" is used to
9294 # record module dependency.
9395 string_config = {}
96+ string_config ["param_connection" ] = config ["param_connection" ]
9497 string_config ["input_connection" ] = config ["input_connection" ]
9598 string_config ["module_connection" ] = module_string_config
9699
@@ -114,6 +117,8 @@ def __init__(self, module):
114117 # Get the packed functions from the pipeline executor.
115118 self ._get_num_outputs = self .module ["get_num_outputs" ]
116119 self ._get_input_pipeline_map = self .module ["get_input_pipeline_map" ]
120+ self ._get_params_group_pipeline_map = self .module ["get_params_group_pipeline_map" ]
121+ self ._set_param = self .module ["set_param" ]
117122
118123 def get_input_pipeline_map (self , name ):
119124 """Using the "name" to get the corresponding subgraph index and also get the "input name"
@@ -125,6 +130,39 @@ def get_input_pipeline_map(self, name):
125130 """
126131 return self ._get_input_pipeline_map (name )
127132
133+ def get_params_group_pipeline_map (self , name ):
134+ """Use the name of the parameters group to get the corresponding runtime module index.
135+
136+ Parameters
137+ ----------
138+ name: str
139+ The parameter group name.
140+
141+ Returns
142+ -------
143+ module_index: int
144+ The index of the runtime module.
145+ """
146+ return self ._get_params_group_pipeline_map (name )
147+
148+ def set_params (self , params_group_name , params_data ):
149+ """Set the parameter group value given the parameter group name. Note that the parameter
150+ group name is declared in the pipeline executor config.
151+
152+ Parameters
153+ ----------
154+ params_group_name : str
155+ The parameters group name.
156+
157+ params_data : Dict[str, NDArray]
158+ A map from parameter name to data.
159+ """
160+ if not params_data :
161+ raise RuntimeError ('"params_data is empty!"' )
162+
163+ for key , val in params_data .items ():
164+ self ._set_param (params_group_name , key , val )
165+
128166 @property
129167 def num_outputs (self ):
130168 """Get the number of outputs.
@@ -311,9 +349,19 @@ def connect(self, binding):
311349 if self .io_owner == binding .io_owner :
312350 raise RuntimeError ("Can not bind itself." )
313351
352+ if self .io_type == "param" and not self .is_pipeline_executor_interface ():
353+ raise RuntimeError (
354+ 'The "param" binding can only be used by a pipeline executor interface!'
355+ )
356+
314357 if not self .is_pipeline_executor_interface () and self .io_type == "input" :
315358 raise RuntimeError ("Module can only bind from output interface!" )
316359
360+ if self .io_type == "param" and binding .io_type != "param" :
361+ raise RuntimeError (
362+ 'A global "param" interface can only be bind with a module "param" interface!'
363+ )
364+
317365 if (
318366 not self .is_pipeline_executor_interface ()
319367 and not binding .is_pipeline_executor_interface ()
@@ -412,6 +460,7 @@ def __init__(self, mod=None):
412460 self .output_type = InferType ()(mod )["main" ].checked_type .ret_type
413461 self .input_bindings = PipelineConfig .BindingList (self , "input" )
414462 self .output_bindings = PipelineConfig .BindingList (self , "output" )
463+ self .param_binding = PipelineConfig .Binding (self , "param" , "param" )
415464
416465 def __eq__ (self , other ):
417466 if isinstance (other , PipelineConfig .ModuleWrapper ):
@@ -427,6 +476,9 @@ def __getitem__(self, key):
427476 if key == "output" :
428477 return self .output_bindings
429478
479+ if key == "param" :
480+ return self .param_binding
481+
430482 raise RuntimeError (f"{ key } not found!" )
431483
432484 raise RuntimeError ('The data type of "key" is not supported!' )
@@ -483,14 +535,21 @@ def __init__(self):
483535 self .mod_wrapper = {}
484536 self .input_bindings = self .BindingList (self , "input" )
485537 self .output_bindings = self .BindingList (self , "output" )
538+ # There is a map of global parameters group and module index.
539+ self .param_group_bindings = self .BindingList (self , "param" )
486540
487541 def __str__ (self ):
488542 # Get configuration information as a string.
489543
490544 # Use topological sort to get correct module order.
491545 self .dag_topology_sort ()
546+ # Getting the parameters dependencies.
547+ param_dump = "Params\n "
548+ for param_name in self .param_group_bindings .bindings :
549+ inf = self .param_group_bindings .bindings [param_name ]
550+ param_dump += str (inf ) + "\n "
492551 # Get the input dependencies.
493- input_dump = "Inputs \n "
552+ input_dump = "\n Inputs \n "
494553 for input_name in self .input_bindings .bindings :
495554 inf = self .input_bindings .bindings [input_name ]
496555 input_dump += str (inf ) + "\n "
@@ -516,7 +575,7 @@ def __str__(self):
516575 for name in sorted (output .keys ()):
517576 output_dump += f" |output({ name } ) : { output [name ]} \n "
518577
519- return input_dump + output_dump + connections_dump
578+ return param_dump + input_dump + output_dump + connections_dump
520579
521580 def __getitem__ (self , key ):
522581 if isinstance (key , tvm .ir .module .IRModule ):
@@ -529,8 +588,12 @@ def __getitem__(self, key):
529588 return self .input_bindings
530589 if key == "output" :
531590 return self .output_bindings
591+ if key == "param_group" :
592+ return self .param_group_bindings
593+
594+ raise RuntimeError (f"{ key } not found!" )
532595
533- raise RuntimeError (f" { key } not found." )
596+ raise RuntimeError (f'The key type " { type ( key ) } " is not supported!' )
534597
535598 def get_config (self ):
536599 """Get the configuration information in dictionary form, this configuration
@@ -541,7 +604,6 @@ def get_config(self):
541604 self .dag_topology_sort ()
542605 mconfig = {}
543606 module_connection = {}
544- input_connection = {}
545607 for mod in self .mod_wrapper :
546608 # Generate pipeline configuration.
547609 mconf = {}
@@ -579,22 +641,33 @@ def get_config(self):
579641 "dev" : module .dev ,
580642 }
581643
582- # Create a map of pipeline input and subgraph input.
583- input_connection = []
584- for input_name in self .input_bindings .bindings :
585- input_dict = self .input_bindings .bindings [input_name ].get_binding_dict ()
586- if "interface_name" not in input_dict ["connection" ][0 ]:
587- raise RuntimeError ("interface_name is missing in connection config!" )
588- # Creating the map of global interface and subgraph interface.
589- input_map = {
590- "global_interface_name" : input_dict ["interface_name" ],
591- "mod_idx" : input_dict ["connection" ][0 ]["mod_idx" ],
592- "module_interface_name" : input_dict ["connection" ][0 ]["interface_name" ],
593- }
594- input_connection .append (input_map )
644+ # Creating a map including pipeline inputs and subgraph inputs.
645+ input_connection = []
646+ for input_name in self .input_bindings .bindings :
647+ input_dict = self .input_bindings .bindings [input_name ].get_binding_dict ()
648+ if "interface_name" not in input_dict ["connection" ][0 ]:
649+ raise RuntimeError ("interface_name is missing in connection config!" )
650+ # Creating the map including global interfaces and subgraph interfaces.
651+ input_map = {
652+ "global_interface_name" : input_dict ["interface_name" ],
653+ "mod_idx" : input_dict ["connection" ][0 ]["mod_idx" ],
654+ "module_interface_name" : input_dict ["connection" ][0 ]["interface_name" ],
655+ }
656+ input_connection .append (input_map )
657+
658+ # Create a map including global parameters groups and modules.
659+ param_connection = []
660+ for param_name in self .param_group_bindings .bindings :
661+ param_dict = self .param_group_bindings .bindings [param_name ].get_binding_dict ()
662+ param_map = {
663+ "global_param_name" : param_dict ["interface_name" ],
664+ "mod_idx" : param_dict ["connection" ][0 ]["mod_idx" ],
665+ }
666+ param_connection .append (param_map )
595667
596668 mconfig ["module_connection" ] = module_connection
597669 mconfig ["input_connection" ] = input_connection
670+ mconfig ["param_connection" ] = param_connection
598671 return mconfig
599672
600673 def dag_topology_sort (self ):
@@ -613,8 +686,12 @@ def dag_topology_sort(self):
613686
614687 mlist += temp_list
615688
689+ mod_wrapper_sort = {}
616690 for mod , i in zip (mlist , range (len (mlist ))):
617691 self .mod_wrapper [mod ].set_idx_name (i )
692+ mod_wrapper_sort [mod ] = self .mod_wrapper [mod ]
693+
694+ self .mod_wrapper = mod_wrapper_sort
618695
619696 def get_mod_idx (self , mod ):
620697 # Return the module index.
0 commit comments