
Commit 6061ae2

add OTF input set support
1 parent 12304e8 commit 6061ae2


10 files changed: +191 −90 lines


cmake/config.cmake

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ set(USE_STACKVM_RUNTIME OFF)
 # Whether enable tiny embedded graph executor.
 set(USE_GRAPH_EXECUTOR ON)
 # Whether enable subgraph runtime.
-set(USE_PIPELINE_EXECUTOR ON)
+set(USE_PIPELINE_EXECUTOR OFF)

 # Whether enable tiny graph executor with CUDA Graph
 set(USE_GRAPH_EXECUTOR_CUDA_GRAPH OFF)

python/tvm/contrib/graph_executor.py

Lines changed: 11 additions & 0 deletions
@@ -156,6 +156,7 @@ def __init__(self, module):
         self._run = module["run"]
         self._get_output = module["get_output"]
         self._get_input = module["get_input"]
+        self._get_input_index = module["get_input_index"]
         self._get_num_outputs = module["get_num_outputs"]
         self._get_num_inputs = module["get_num_inputs"]
         self._load_params = module["load_params"]
@@ -242,6 +243,16 @@ def get_input(self, index, out=None):

         return self._get_input(index)

+    def get_input_index(self, name):
+        """Get the index of a named input.
+
+        Parameters
+        ----------
+        name : str
+            The name of the input.
+        """
+        return self._get_input_index(name)
+
     def get_output(self, index, out=None):
         """Get index-th output to out

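Usage note (not part of this commit): a minimal sketch of how the new GraphModule.get_input_index API combines with the existing index-based calls. The compiled library `lib`, the device, the input name "data", and the input shape are all assumptions for illustration.

    import numpy as np
    import tvm
    from tvm.contrib import graph_executor

    # `lib` is assumed to be a module compiled elsewhere (e.g. by relay.build).
    dev = tvm.cpu(0)
    gmod = graph_executor.GraphModule(lib["default"](dev))

    # Resolve the positional index of a named input, then drive the index-based APIs.
    idx = gmod.get_input_index("data")
    data = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")  # hypothetical shape
    gmod.set_input(idx, tvm.nd.array(data))
    gmod.run()
    out = gmod.get_output(0)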

python/tvm/contrib/pipeline_executor.py

Lines changed: 5 additions & 5 deletions
@@ -22,7 +22,7 @@


 def pipeline_executor_enabled():
-    """ check if pipeline executor enabled. """
+    """check if pipeline executor enabled."""
     pipeline_enabled = False
     try:
         pipelinecreate = tvm._ffi.get_global_func("tvm.pipeline_executor.create")
@@ -153,7 +153,7 @@ def __init__(self, graph_modules, pipeline_config):
         self._get_num_outputs = module["get_num_outputs"]
         self._get_num_inputs = module["get_num_inputs"]

-    def set_input(self, key, value, params=None, modindx=0):
+    def set_input(self, key, value, modindx=1, params=None):
         """Set inputs to the module via kwargs

         Parameters
@@ -167,13 +167,13 @@ def set_input(self, key, value, params=None, modindx=0):
         params : dict of str to NDArray
             Additional arguments
         """
+        assert modindx >= 1
         if key is not None:
-            self.graph_modules_[modindx].set_input(key, value)
+            self._set_input(key, tvm.nd.array(value, tvm.cpu()), modindx)

         if params:
             for param in params:
-                self.graph_modules_[modindx].set_input(**param)
-                indx = indx + 1
+                self.graph_modules_[modindx - 1].set_input(**param)

     def run(self):
         """Run forward execution of the graph"""

src/runtime/graph_executor/graph_executor.cc

Lines changed: 8 additions & 0 deletions
@@ -502,6 +502,14 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
       dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
       this->ShareParams(dynamic_cast<const GraphExecutor&>(*module.operator->()), &strm);
     });
+  } else if (name == "get_input_index") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0])) {
+        *rv = this->GetInputIndex(args[0].operator String());
+      } else {
+        *rv = args[0];
+      }
+    });
   } else {
     return PackedFunc();
   }
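Usage note (not part of this commit): at the FFI boundary the new packed function resolves string names through GraphExecutor::GetInputIndex and passes anything else back unchanged. A Python-side sketch; `gmod` and the input name "data" are assumptions.

    # `gmod` is assumed to be an existing graph_executor.GraphModule.
    get_input_index = gmod.module["get_input_index"]
    print(get_input_index("data"))  # string: resolved to the input's positional index
    print(get_input_index(3))       # non-string: returned as-is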

src/runtime/pipeline/pipeline_executor.cc

Lines changed: 28 additions & 23 deletions
@@ -34,7 +34,11 @@ void SubGraphRuntime::Stop() { pipeline_stop(runtimes); }
 /*!
  * \brief Run all the operations one by one.
  */
-void SubGraphRuntime::Run() { pipeline_run(runtimes); }
+void SubGraphRuntime::Run() {
+  pipeline_run(runtimes, input_int_map);
+  /* Clear the input map
+   */
+}

 void SubGraphRuntime::Init(const Array<tvm::runtime::Module>& modules,
                            const std::string& pipeline_json) {
@@ -52,19 +56,11 @@ void SubGraphRuntime::Init(const Array<tvm::runtime::Module>& modules,
  * \param modIndx The runtime index.
  */
 void SubGraphRuntime::SetInput(int index, DLTensor* data_in, int modIndx) {
-  auto gruntime = runtimes[modIndx];
-  gruntime->runtimePtr->SetInput(index, data_in);
-}
-
-/*!
- * \brief set index-th input to the modIndx-th graph.
- * \param index The input index.
- * \param data_in The input data.
- * \param modIndx The runtime index.
- */
-void SubGraphRuntime::SetInput(const std::string& name, DLTensor* data_in, int modIndx) {
-  auto gruntime = runtimes[modIndx];
-  gruntime->runtimePtr->SetInput(name, data_in);
+  if (1 == modIndx) {
+    runtimes[0]->runtimePtr->SetInput(index, data_in);
+  } else {
+    pipeline_setinput(input_int_map, index, data_in, modIndx);
+  }
 }

 /*!
@@ -98,9 +94,15 @@ NDArray SubGraphRuntime::GetInput(int index, int mIndx) const {
   return gruntime->runtimePtr->GetInput(index);
 }

-NDArray SubGraphRuntime::GetInput(const std::string& name, int mIndx) const {
-  auto gruntime = runtimes[mIndx];
-  return gruntime->runtimePtr->GetInput(name);
+/*!
+ * \brief Return input index for given input name.
+ * \param name The input name.
+ *
+ * \return int corresponding to given input node name.
+ */
+int SubGraphRuntime::GetInputIndex(const string& name, int mIndx) const {
+  auto gruntime = runtimes[mIndx - 1];
+  return gruntime->runtimePtr->GetInputIndex(name);
 }

 /*!
@@ -120,7 +122,8 @@ Array<NDArray> SubGraphRuntime::GetOutput(bool syncPoll) {

 PackedFunc SubGraphRuntime::GetFunction(const std::string& name,
                                         const ObjectPtr<Object>& sptr_to_self) {
-  // Return member functions during query.
+  /* Return member functions during query.
+   */
   if (name == "set_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       /* Default use first runtime index value.
@@ -130,7 +133,8 @@ PackedFunc SubGraphRuntime::GetFunction(const std::string& name,
         modIndx = static_cast<int>(args[2]);
       }
       if (String::CanConvertFrom(args[0])) {
-        this->SetInput(args[0].operator String(), args[1], modIndx);
+        int index = this->GetInputIndex(args[0].operator String(), modIndx);
+        this->SetInput(index, args[1], modIndx);
       } else {
         this->SetInput(static_cast<int>(args[0]), args[1], modIndx);
       }
@@ -145,17 +149,18 @@ PackedFunc SubGraphRuntime::GetFunction(const std::string& name,
     });
   } else if (name == "get_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      int in_idx = 0, graph_idx = 0;
+      int in_idx = 0, mod_idx = 0;
       if (args.num_args == 2) {
-        graph_idx = args[1];
+        mod_idx = args[1];
       }

       if (String::CanConvertFrom(args[0])) {
-        *rv = this->GetInput(args[0].operator String(), graph_idx);
+        int index = this->GetInputIndex(args[0].operator String(), mod_idx);
+        *rv = this->GetInput(index, mod_idx);
       } else {
         in_idx = args[0];
         if (in_idx >= 0) {
-          *rv = this->GetInput(in_idx, graph_idx);
+          *rv = this->GetInput(in_idx, mod_idx);
         }
       }
     });

src/runtime/pipeline/pipeline_executor.h

Lines changed: 23 additions & 9 deletions
@@ -43,6 +43,7 @@ namespace runtime {
  */
 class TVM_DLL SubGraphRuntime : public ModuleNode {
  public:
+  SubGraphRuntime() { input_int_map = make_shared<MOD_DLDATA_MAP>(); }
   ~SubGraphRuntime() {
     /* stop pipeline threads and release data in deconstructor.
      */
@@ -82,9 +83,20 @@ class TVM_DLL SubGraphRuntime : public ModuleNode {
   * \param data_in The input data.
   */
  void SetInput(int index, DLTensor* data_in, int modIndx);
-  void SetInput(const std::string& name, DLTensor* data_in, int modIndx);
+
+  /*!
+   * \brief Get the index-th input.
+   * \param index The input index.
+   * \return The input data.
+   */
  NDArray GetInput(int index, int mIndx) const;
-  NDArray GetInput(const std::string& name, int mIndx) const;
+
+  /*!
+   * \brief Get the input index by name.
+   * \param name The input name.
+   * \return The input index.
+   */
+  int GetInputIndex(const string& name, int mIndx) const;
  /*!
   * \brief Get the number of outputs
   *
@@ -111,7 +123,7 @@ class TVM_DLL SubGraphRuntime : public ModuleNode {
     std::string key;
     reader->BeginObject();
     int mod_indx = 0;
-    unordered_map<int, unordered_map<int, int>> output;
+    unordered_map<int, unordered_map<int, string>> output;
     while (reader->NextObjectItem(&key)) {
       if (key == "mod_indx") {
         reader->Read(&mod_indx);
@@ -120,27 +132,28 @@ class TVM_DLL SubGraphRuntime : public ModuleNode {
       reader->BeginArray();
       while (reader->NextArrayItem()) {
         int output_indx = -1;
-        unordered_map<int, int> depend;
+        unordered_map<int, string> depend;
         reader->BeginObject();
         while (reader->NextObjectItem(&key)) {
           if (key == "output_indx") {
             reader->Read(&output_indx);
           }
           if (key == "dependent") {
             reader->BeginArray();
-            int dep_mod_indx = -1, input_indx = -1;
+            int dep_mod_indx = -1;
+            string inputName;
             while (reader->NextArrayItem()) {
               reader->BeginObject();
               while (reader->NextObjectItem(&key)) {
                 if (key == "mod_indx") {
                   reader->Read(&dep_mod_indx);
                 }
-                if (key == "input_indx") {
-                  reader->Read(&input_indx);
+                if (key == "input_name") {
+                  reader->Read(&inputName);
                 }
               }
-              if (dep_mod_indx >= 0 && input_indx >= 0) {
-                depend[dep_mod_indx] = input_indx;
+              if (dep_mod_indx >= 0) {
+                depend[dep_mod_indx] = inputName;
               }
             }
           }
@@ -162,6 +175,7 @@ class TVM_DLL SubGraphRuntime : public ModuleNode {
   vector<NDArray> output_entry_;
   PIPELINE_CONF pipeline_conf;
   vector<shared_ptr<RuntimeItem>> runtimes;
+  MOD_DLDATA_MAP_PTR input_int_map;
   size_t outpuNumber = 0;
 };
 } // namespace runtime
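Configuration note (not part of this commit): the Load() reader above now expects each "dependent" entry to carry an "input_name" instead of an "input_indx". A sketch of a configuration fragment such a reader could parse; the outer list shape, the module and output indices, and the input name are assumptions for illustration.

    import json

    # One connection: output 1 of module 1 feeds the input named "data_b" of module 2.
    pipeline_json = json.dumps([
        {
            "mod_indx": 1,
            "output": [
                {
                    "output_indx": 1,
                    "dependent": [
                        {"mod_indx": 2, "input_name": "data_b"},
                    ],
                },
            ],
        },
    ])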

src/runtime/pipeline/pipeline_function.cc

Lines changed: 55 additions & 5 deletions
@@ -39,7 +39,7 @@ void pipeline_pipeline_run(const int& num, const shared_ptr<RuntimeItem>& curRunItem,
   curRunItem->Run();

   vector<shared_ptr<OutputData>> outputs;
-  curRunItem->GetOutput2(&outputs);
+  curRunItem->GetOutput(&outputs);
   pipeline_queue_push(nextQueue, &outputs);
   curRunItem->notifyDataReadyToNext();
 }
@@ -53,6 +53,24 @@ thread* pipeline_pipeline_init(SHARED_RUNTIME_VEC* runtimes) {
   return NULL;
 }

+RUNTIME_PIPELINE_OUTPUT_CONF
+pipeline_name_to_indx(const Array<Module>& graphRuntimes,
+                      const RUNTIME_PIPELINE_OUTPUT_CONF_STR& pConfStr) {
+  RUNTIME_PIPELINE_OUTPUT_CONF confRet;
+  for (auto outConf : pConfStr) {
+    for (auto conf : outConf.second) {
+      int modIndx = conf.first;
+      if (modIndx) {
+        auto mGetIndex = ((Module)graphRuntimes[modIndx - 1]).GetFunction("get_input_index");
+        confRet[outConf.first][modIndx] = (static_cast<int>(mGetIndex(conf.second))) + 1;
+      } else {
+        confRet[outConf.first][modIndx] = stoi(conf.second);
+      }
+    }
+  }
+  return confRet;
+}
+
 size_t pipeline_init(Array<Module> graphRuntimes, SHARED_RUNTIME_VEC* runtimes,
                      PIPELINE_CONF* pipeline_conf) {
   int outputNum = 0;
@@ -62,7 +80,10 @@ size_t pipeline_init(Array<Module> graphRuntimes, SHARED_RUNTIME_VEC* runtimes,
     /* runtimeIndx start from 1.
      */
     int runtimeIndx = i + 1;
-    auto& pConf = pipeline_conf->at(runtimeIndx);
+    /* get dependency configuration information.
+     */
+    auto pConf = pipeline_name_to_indx(graphRuntimes, pipeline_conf->at(runtimeIndx));
+
     auto runItem = make_shared<RuntimeItem>(graphRuntimes[i], sub_queue, &pConf, runtimeIndx);
     runtimes->push_back(runItem);
     /* set prev and next for RuntimeItem, runtime need these information to
@@ -99,12 +120,23 @@ bool pipeline_queue_poll(QUEUE* queue, RuntimeData* runtimeData) {
   return q_poll<SLOT, RuntimeData>(queue, runtimeData);
 }

-void pipeline_run(const SHARED_RUNTIME_VEC& runtimes) {
+void pipeline_run(const SHARED_RUNTIME_VEC& runtimes, const MOD_DLDATA_MAP_PTR indxInputs) {
   shared_ptr<RuntimeItem> runtime = runtimes.front();
   runtime->Run();
-
+  /* Get runtime output
+   */
   vector<shared_ptr<OutputData>> outputs;
-  runtime->GetOutput2(&outputs);
+  runtime->GetOutput(&outputs);
+
+  /* Store input data for the runtimes after the first runtime
+   */
+  for (auto modInputs : *indxInputs) {
+    int modIndx = modInputs.first;
+    for (auto inputs : modInputs.second) {
+      outputs.push_back(make_shared<OutputData>(modIndx, inputs.first + 1, inputs.second->data));
+    }
+  }
+
   pipeline_queue_push(runtime->next->queue, &outputs);
   runtime->notifyDataReadyToNext();
   return;
@@ -133,3 +165,21 @@ void pipeline_stop(const SHARED_RUNTIME_VEC& runtimes) {
     runtimes.front()->notifyNextExit();
   }
 }
+
+void pipeline_setinput(MOD_DLDATA_MAP_PTR input_int_map, const int index, const DLTensor* data_in,
+                       const int modIndx) {
+  if (input_int_map->find(modIndx) == input_int_map->end()) {
+    DLDATA_MAP dlmap;
+    dlmap[index] = nullptr;
+    input_int_map->insert({modIndx, dlmap});
+  } else if (input_int_map->at(modIndx).find(index) == input_int_map->at(modIndx).end()) {
+    input_int_map->at(modIndx)[index] = nullptr;
+  }
+
+  TENSOR_DATA tensor_data = input_int_map->at(modIndx)[index];
+  if (tensor_data == nullptr) {
+    tensor_data = make_shared<TensorData>();
+    input_int_map->at(modIndx)[index] = tensor_data;
+  }
+  tensor_data->CreateCopyFrom(data_in, kDLCPU, 0);
+}
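Implementation note (not part of this commit): a rough Python analogue of what pipeline_name_to_indx and pipeline_setinput do, using plain dicts in place of RUNTIME_PIPELINE_OUTPUT_CONF and MOD_DLDATA_MAP. This is only an illustration of the C++ logic above, not a TVM API.

    def pipeline_name_to_indx(graph_modules, conf_str):
        """Map {output_idx: {mod_idx: input_name}} to {output_idx: {mod_idx: input_idx}}."""
        resolved = {}
        for out_idx, deps in conf_str.items():
            for mod_idx, value in deps.items():
                if mod_idx:
                    # Downstream module: resolve the name via its get_input_index packed
                    # function, then shift to the 1-based indexing the pipeline runtime uses.
                    get_index = graph_modules[mod_idx - 1].module["get_input_index"]
                    resolved.setdefault(out_idx, {})[mod_idx] = int(get_index(value)) + 1
                else:
                    # Module index 0 keeps a numeric value (stoi in the C++ code).
                    resolved.setdefault(out_idx, {})[mod_idx] = int(value)
        return resolved

    def pipeline_setinput(input_map, index, data, mod_idx):
        """Buffer a copy of `data` for module `mod_idx` until pipeline_run forwards it."""
        input_map.setdefault(mod_idx, {})[index] = data.copy()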

src/runtime/pipeline/pipeline_function.h

Lines changed: 9 additions & 2 deletions
@@ -19,6 +19,7 @@
 #ifndef TVM_RUNTIME_PIPELINE_PIPELINE_FUNCTION_H_
 #define TVM_RUNTIME_PIPELINE_PIPELINE_FUNCTION_H_
 #include <memory>
+#include <string>
 #include <unordered_map>
 #include <vector>

@@ -27,15 +28,21 @@
 using namespace std;
 using namespace tvm::runtime;
 typedef vector<shared_ptr<RuntimeItem>> SHARED_RUNTIME_VEC;
-typedef unordered_map<int, unordered_map<int, unordered_map<int, int>>> PIPELINE_CONF;
+typedef unordered_map<int, unordered_map<int, unordered_map<int, string>>> PIPELINE_CONF;
+typedef shared_ptr<TensorData> TENSOR_DATA;
+typedef unordered_map<int, TENSOR_DATA> DLDATA_MAP;
+typedef unordered_map<int, DLDATA_MAP> MOD_DLDATA_MAP;
+typedef shared_ptr<MOD_DLDATA_MAP> MOD_DLDATA_MAP_PTR;

 size_t pipeline_init(Array<Module> graphRuntimes, SHARED_RUNTIME_VEC* runtimes,
                      PIPELINE_CONF* pipeline_conf);
-void pipeline_run(const SHARED_RUNTIME_VEC& runtimes);
+void pipeline_run(const SHARED_RUNTIME_VEC& runtimes, const MOD_DLDATA_MAP_PTR indxInputs);
 inline void pipeline_queue_push(QUEUE* queue, vector<shared_ptr<OutputData>>* outputs);
 bool pipeline_queue_poll(QUEUE* queue, RuntimeData* runtimeData);
 bool pipeline_poll(vector<NDArray>* output, const SHARED_RUNTIME_VEC& runtimes,
                    const bool bSync = false);
 void pipeline_stop(const SHARED_RUNTIME_VEC& runtimes);
+void pipeline_setinput(MOD_DLDATA_MAP_PTR input_int_map, const int index, const DLTensor* data_in,
+                       const int modIndx);

 #endif  // TVM_RUNTIME_PIPELINE_PIPELINE_FUNCTION_H_
