@@ -34,7 +34,11 @@ void SubGraphRuntime::Stop() { pipeline_stop(runtimes); }
3434/* !
3535 * \brief Run all the operations one by one.
3636 */
37- void SubGraphRuntime::Run () { pipeline_run (runtimes); }
37+ void SubGraphRuntime::Run () {
38+ pipeline_run (runtimes, input_int_map);
39+ /* Clear the input map
40+ */
41+ }
3842
3943void SubGraphRuntime::Init (const Array<tvm::runtime::Module>& modules,
4044 const std::string& pipeline_json) {
@@ -52,19 +56,11 @@ void SubGraphRuntime::Init(const Array<tvm::runtime::Module>& modules,
5256 * \param modIndx The runtime index.
5357 */
5458void SubGraphRuntime::SetInput (int index, DLTensor* data_in, int modIndx) {
55- auto gruntime = runtimes[modIndx];
56- gruntime->runtimePtr ->SetInput (index, data_in);
57- }
58-
59- /* !
60- * \brief set index-th input to the modIndx-th graph.
61- * \param index The input index.
62- * \param data_in The input data.
63- * \param modIndx The runtime index.
64- */
65- void SubGraphRuntime::SetInput (const std::string& name, DLTensor* data_in, int modIndx) {
66- auto gruntime = runtimes[modIndx];
67- gruntime->runtimePtr ->SetInput (name, data_in);
59+ if (1 == modIndx) {
60+ runtimes[0 ]->runtimePtr ->SetInput (index, data_in);
61+ } else {
62+ pipeline_setinput (input_int_map, index, data_in, modIndx);
63+ }
6864}
6965
7066/* !
@@ -98,9 +94,15 @@ NDArray SubGraphRuntime::GetInput(int index, int mIndx) const {
9894 return gruntime->runtimePtr ->GetInput (index);
9995}
10096
101- NDArray SubGraphRuntime::GetInput (const std::string& name, int mIndx ) const {
102- auto gruntime = runtimes[mIndx ];
103- return gruntime->runtimePtr ->GetInput (name);
97+ /* !
98+ * \brief Return input index for given input name.
99+ * \param name The input name.
100+ *
101+ * \return int corresponding to given input node name.
102+ */
103+ int SubGraphRuntime::GetInputIndex (const string& name, int mIndx ) const {
104+ auto gruntime = runtimes[mIndx - 1 ];
105+ return gruntime->runtimePtr ->GetInputIndex (name);
104106}
105107
106108/* !
@@ -120,7 +122,8 @@ Array<NDArray> SubGraphRuntime::GetOutput(bool syncPoll) {
120122
121123PackedFunc SubGraphRuntime::GetFunction (const std::string& name,
122124 const ObjectPtr<Object>& sptr_to_self) {
123- // Return member functions during query.
125+ /* Return member functions during query.
126+ */
124127 if (name == " set_input" ) {
125128 return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
126129 /* Default use first runtime index value.
@@ -130,7 +133,8 @@ PackedFunc SubGraphRuntime::GetFunction(const std::string& name,
130133 modIndx = static_cast <int >(args[2 ]);
131134 }
132135 if (String::CanConvertFrom (args[0 ])) {
133- this ->SetInput (args[0 ].operator String (), args[1 ], modIndx);
136+ int index = this ->GetInputIndex (args[0 ].operator String (), modIndx);
137+ this ->SetInput (index, args[1 ], modIndx);
134138 } else {
135139 this ->SetInput (static_cast <int >(args[0 ]), args[1 ], modIndx);
136140 }
@@ -145,17 +149,18 @@ PackedFunc SubGraphRuntime::GetFunction(const std::string& name,
145149 });
146150 } else if (name == " get_input" ) {
147151 return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
148- int in_idx = 0 , graph_idx = 0 ;
152+ int in_idx = 0 , mod_idx = 0 ;
149153 if (args.num_args == 2 ) {
150- graph_idx = args[1 ];
154+ mod_idx = args[1 ];
151155 }
152156
153157 if (String::CanConvertFrom (args[0 ])) {
154- *rv = this ->GetInput (args[0 ].operator String (), graph_idx);
158+ int index = this ->GetInputIndex (args[0 ].operator String (), mod_idx);
159+ *rv = this ->GetInput (index, mod_idx);
155160 } else {
156161 in_idx = args[0 ];
157162 if (in_idx >= 0 ) {
158- *rv = this ->GetInput (in_idx, graph_idx );
163+ *rv = this ->GetInput (in_idx, mod_idx );
159164 }
160165 }
161166 });
0 commit comments