|
31 | 31 |
|
32 | 32 | #include <algorithm> |
33 | 33 | #include <functional> |
| 34 | +#include <memory> |
34 | 35 | #include <numeric> |
35 | | -#include <vector> |
36 | 36 | #include <string> |
37 | | -#include <memory> |
| 37 | +#include <unordered_set> |
38 | 38 | #include <utility> |
| 39 | +#include <vector> |
39 | 40 |
|
40 | 41 | namespace tvm { |
41 | 42 | namespace runtime { |
@@ -78,18 +79,21 @@ void GraphRuntime::Init(const std::string& graph_json, |
78 | 79 | ctxs_ = ctxs; |
79 | 80 | this->SetupStorage(); |
80 | 81 | this->SetupOpExecs(); |
| 82 | + for (size_t i = 0; i < input_nodes_.size(); i++) { |
| 83 | + uint32_t nid = input_nodes_[i]; |
| 84 | + std::string& name = nodes_[nid].name; |
| 85 | + input_map_[name] = i; |
| 86 | + } |
81 | 87 | } |
82 | 88 | /*! |
83 | 89 | * \brief Get the input index given the name of input. |
84 | 90 | * \param name The name of the input. |
85 | 91 | * \return The index of input. |
86 | 92 | */ |
87 | 93 | int GraphRuntime::GetInputIndex(const std::string& name) { |
88 | | - for (size_t i = 0; i< input_nodes_.size(); ++i) { |
89 | | - uint32_t nid = input_nodes_[i]; |
90 | | - if (nodes_[nid].name == name) { |
91 | | - return static_cast<int>(i); |
92 | | - } |
| 94 | + auto it = input_map_.find(name); |
| 95 | + if (it != input_map_.end()) { |
| 96 | + return it->second; |
93 | 97 | } |
94 | 98 | LOG(WARNING) << "Warning: cannot find \"" << name << "\" among input"; |
95 | 99 | return -1; |
@@ -125,16 +129,8 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { |
125 | 129 | } |
126 | 130 |
|
127 | 131 | // Update the data pointer for each argument of each op |
128 | | - for (auto& op_arg : op_args_) { |
129 | | - if (op_arg) { |
130 | | - const auto it = op_arg->input_entry_ids.find(eid); |
131 | | - if (it != op_arg->input_entry_ids.end()) { |
132 | | - for (const auto i : it->second) { |
133 | | - DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle); |
134 | | - t->data = data_ref->data; |
135 | | - } |
136 | | - } |
137 | | - } |
| 132 | + for (DLTensor* t : input_dltensors_[eid]) { |
| 133 | + t->data = data_ref->data; |
138 | 134 | } |
139 | 135 | } |
140 | 136 | /*! |
@@ -324,34 +320,38 @@ void GraphRuntime::SetupStorage() { |
324 | 320 |
|
325 | 321 | void GraphRuntime::SetupOpExecs() { |
326 | 322 | op_execs_.resize(this->GetNumOfNodes()); |
327 | | - op_args_.resize(this->GetNumOfNodes()); |
| 323 | + input_dltensors_.resize(num_node_entries()); |
| 324 | + std::unordered_set<uint32_t> input_node_eids; |
| 325 | + for (size_t i = 0; i < input_nodes_.size(); i++) { |
| 326 | + uint32_t nid = input_nodes_[i]; |
| 327 | + input_node_eids.insert(entry_id(nid, 0)); |
| 328 | + } |
| 329 | + |
328 | 330 | // setup the array and requirements. |
329 | 331 | for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) { |
330 | 332 | const auto& inode = nodes_[nid]; |
331 | 333 | if (inode.op_type == "null") continue; |
332 | 334 | std::vector<DLTensor> args; |
333 | | - std::vector<uint32_t> input_entry_ids; |
334 | 335 | for (const auto& e : inode.inputs) { |
335 | 336 | uint32_t eid = this->entry_id(e); |
336 | 337 | args.push_back(*(data_entry_[eid].operator->())); |
337 | | - input_entry_ids.push_back(eid); |
338 | 338 | } |
339 | 339 | for (uint32_t index = 0; index < inode.param.num_outputs; ++index) { |
340 | 340 | uint32_t eid = this->entry_id(nid, index); |
341 | 341 | args.push_back(*(data_entry_[eid].operator->())); |
342 | 342 | } |
343 | 343 | CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; |
344 | 344 |
|
345 | | - std::tie(op_execs_[nid], op_args_[nid]) = |
| 345 | + std::shared_ptr<OpArgs> op_args = nullptr; |
| 346 | + std::tie(op_execs_[nid], op_args) = |
346 | 347 | CreateTVMOp(inode.param, args, inode.inputs.size()); |
347 | | - auto& entry_to_input_pos = op_args_[nid]->input_entry_ids; |
348 | | - for (uint32_t i = 0; i < input_entry_ids.size(); ++i) { |
349 | | - const auto eid = input_entry_ids[i]; |
350 | | - auto it = entry_to_input_pos.find(eid); |
351 | | - if (it == entry_to_input_pos.end()) { |
352 | | - entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i}); |
353 | | - } else { |
354 | | - it->second.push_back(i); |
| 348 | + |
| 349 | + for (size_t i = 0; i < inode.inputs.size(); i++) { |
| 350 | + uint32_t eid = this->entry_id(inode.inputs[i]); |
| 351 | + // check if op input is model input |
| 352 | + if (input_node_eids.count(eid) > 0) { |
| 353 | + input_dltensors_[eid].push_back( |
| 354 | + static_cast<DLTensor*>(op_args->arg_values[i].v_handle)); |
355 | 355 | } |
356 | 356 | } |
357 | 357 | } |
|
0 commit comments