Skip to content

Commit e335da7

Browse files
committed
[runtime] reduce set_input and set_input_zero_copy overhead
1 parent aee16d8 commit e335da7

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

src/runtime/graph/graph_runtime.cc

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@
3131

3232
#include <algorithm>
3333
#include <functional>
34+
#include <memory>
3435
#include <numeric>
35-
#include <vector>
3636
#include <string>
37-
#include <memory>
37+
#include <unordered_set>
3838
#include <utility>
39+
#include <vector>
3940

4041
namespace tvm {
4142
namespace runtime {
@@ -78,18 +79,21 @@ void GraphRuntime::Init(const std::string& graph_json,
7879
ctxs_ = ctxs;
7980
this->SetupStorage();
8081
this->SetupOpExecs();
82+
for (size_t i = 0; i < input_nodes_.size(); i++) {
83+
uint32_t nid = input_nodes_[i];
84+
std::string& name = nodes_[nid].name;
85+
input_map_[name] = i;
86+
}
8187
}
8288
/*!
8389
* \brief Get the input index given the name of input.
8490
* \param name The name of the input.
8591
* \return The index of input.
8692
*/
8793
int GraphRuntime::GetInputIndex(const std::string& name) {
88-
for (size_t i = 0; i< input_nodes_.size(); ++i) {
89-
uint32_t nid = input_nodes_[i];
90-
if (nodes_[nid].name == name) {
91-
return static_cast<int>(i);
92-
}
94+
auto it = input_map_.find(name);
95+
if (it != input_map_.end()) {
96+
return it->second;
9397
}
9498
LOG(WARNING) << "Warning: cannot find \"" << name << "\" among input";
9599
return -1;
@@ -125,16 +129,8 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
125129
}
126130

127131
// Update the data pointer for each argument of each op
128-
for (auto& op_arg : op_args_) {
129-
if (op_arg) {
130-
const auto it = op_arg->input_entry_ids.find(eid);
131-
if (it != op_arg->input_entry_ids.end()) {
132-
for (const auto i : it->second) {
133-
DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle);
134-
t->data = data_ref->data;
135-
}
136-
}
137-
}
132+
for (DLTensor* t : input_dltensors_[eid]) {
133+
t->data = data_ref->data;
138134
}
139135
}
140136
/*!
@@ -324,34 +320,38 @@ void GraphRuntime::SetupStorage() {
324320

325321
void GraphRuntime::SetupOpExecs() {
326322
op_execs_.resize(this->GetNumOfNodes());
327-
op_args_.resize(this->GetNumOfNodes());
323+
input_dltensors_.resize(num_node_entries());
324+
std::unordered_set<uint32_t> input_node_eids;
325+
for (size_t i = 0; i < input_nodes_.size(); i++) {
326+
uint32_t nid = input_nodes_[i];
327+
input_node_eids.insert(entry_id(nid, 0));
328+
}
329+
328330
// setup the array and requirements.
329331
for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
330332
const auto& inode = nodes_[nid];
331333
if (inode.op_type == "null") continue;
332334
std::vector<DLTensor> args;
333-
std::vector<uint32_t> input_entry_ids;
334335
for (const auto& e : inode.inputs) {
335336
uint32_t eid = this->entry_id(e);
336337
args.push_back(*(data_entry_[eid].operator->()));
337-
input_entry_ids.push_back(eid);
338338
}
339339
for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
340340
uint32_t eid = this->entry_id(nid, index);
341341
args.push_back(*(data_entry_[eid].operator->()));
342342
}
343343
CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
344344

345-
std::tie(op_execs_[nid], op_args_[nid]) =
345+
std::shared_ptr<OpArgs> op_args = nullptr;
346+
std::tie(op_execs_[nid], op_args) =
346347
CreateTVMOp(inode.param, args, inode.inputs.size());
347-
auto& entry_to_input_pos = op_args_[nid]->input_entry_ids;
348-
for (uint32_t i = 0; i < input_entry_ids.size(); ++i) {
349-
const auto eid = input_entry_ids[i];
350-
auto it = entry_to_input_pos.find(eid);
351-
if (it == entry_to_input_pos.end()) {
352-
entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i});
353-
} else {
354-
it->second.push_back(i);
348+
349+
for (size_t i = 0; i < inode.inputs.size(); i++) {
350+
uint32_t eid = this->entry_id(inode.inputs[i]);
351+
// check if op input is model input
352+
if (input_node_eids.count(eid) > 0) {
353+
input_dltensors_[eid].push_back(
354+
static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
355355
}
356356
}
357357
}

src/runtime/graph/graph_runtime.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ struct TVMOpParam {
7070
class GraphRuntime : public ModuleNode {
7171
struct OpArgs {
7272
std::vector<DLTensor> args;
73-
std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
7473
std::vector<TVMValue> arg_values;
7574
std::vector<int> arg_tcodes;
7675
std::vector<int64_t> shape_data;
@@ -399,6 +398,10 @@ class GraphRuntime : public ModuleNode {
399398
std::vector<Node> nodes_;
400399
/*! \brief The argument nodes. */
401400
std::vector<uint32_t> input_nodes_;
401+
/*! \brief Map of input names to input indices. */
402+
std::unordered_map<std::string, uint32_t> input_map_;
403+
/*! \brief Used for quick node input DLTensor* lookup given an input eid. */
404+
std::vector<std::vector<DLTensor*>> input_dltensors_;
402405
/*! \brief Used for quick entry indexing. */
403406
std::vector<uint32_t> node_row_ptr_;
404407
/*! \brief Output entries. */
@@ -417,8 +420,6 @@ class GraphRuntime : public ModuleNode {
417420
std::vector<size_t> data_alignment_;
418421
/*! \brief Operator on each node. */
419422
std::vector<std::function<void()> > op_execs_;
420-
/*! \brief Arg info of TVM ops */
421-
std::vector<std::shared_ptr<OpArgs> > op_args_;
422423
};
423424

424425
std::vector<TVMContext> GetAllContext(const TVMArgs& args);

0 commit comments

Comments (0)