Merged
58 changes: 29 additions & 29 deletions src/runtime/graph/graph_runtime.cc
@@ -31,11 +31,12 @@
 
 #include <algorithm>
 #include <functional>
+#include <memory>
 #include <numeric>
-#include <vector>
 #include <string>
-#include <memory>
+#include <unordered_set>
+#include <utility>
+#include <vector>
 
 namespace tvm {
 namespace runtime {
@@ -78,18 +79,21 @@ void GraphRuntime::Init(const std::string& graph_json,
   ctxs_ = ctxs;
   this->SetupStorage();
   this->SetupOpExecs();
+  for (size_t i = 0; i < input_nodes_.size(); i++) {
+    const uint32_t nid = input_nodes_[i];
+    std::string& name = nodes_[nid].name;
Review comment: const or just inline this into next line.
Contributor Author: Fixed.
+    input_map_[name] = i;
+  }
 }
 /*!
  * \brief Get the input index given the name of input.
  * \param name The name of the input.
  * \return The index of input.
  */
 int GraphRuntime::GetInputIndex(const std::string& name) {
-  for (size_t i = 0; i< input_nodes_.size(); ++i) {
-    uint32_t nid = input_nodes_[i];
-    if (nodes_[nid].name == name) {
-      return static_cast<int>(i);
-    }
+  auto it = input_map_.find(name);
+  if (it != input_map_.end()) {
+    return it->second;
   }
   LOG(WARNING) << "Warning: cannot find \"" << name << "\" among input";
   return -1;
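
Taken together with the Init() change above, this hunk turns GetInputIndex from a per-call linear scan over input_nodes_ into a single hash lookup against input_map_, which is populated once at load time. A minimal standalone sketch of the same build-once, look-up-many pattern; InputIndexCache and Build are illustrative names, not TVM API:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for the runtime's input bookkeeping.
struct InputIndexCache {
  std::unordered_map<std::string, uint32_t> input_map;

  // Mirrors the loop added to Init(): populate the map once, at load time.
  void Build(const std::vector<std::string>& input_names) {
    for (size_t i = 0; i < input_names.size(); ++i) {
      input_map[input_names[i]] = static_cast<uint32_t>(i);
    }
  }

  // Mirrors the rewritten GetInputIndex(): one hash lookup, -1 on a miss.
  int GetInputIndex(const std::string& name) const {
    auto it = input_map.find(name);
    return it != input_map.end() ? static_cast<int>(it->second) : -1;
  }
};

int main() {
  InputIndexCache cache;
  cache.Build({"data", "weight", "bias"});
  std::cout << cache.GetInputIndex("weight") << "\n";   // prints 1
  std::cout << cache.GetInputIndex("missing") << "\n";  // prints -1
  return 0;
}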
@@ -125,16 +129,8 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
   }
 
   // Update the data pointer for each argument of each op
-  for (auto& op_arg : op_args_) {
-    if (op_arg) {
-      const auto it = op_arg->input_entry_ids.find(eid);
-      if (it != op_arg->input_entry_ids.end()) {
-        for (const auto i : it->second) {
-          DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle);
-          t->data = data_ref->data;
-        }
-      }
-    }
+  for (DLTensor* t : input_dltensors_[eid]) {
+    t->data = data_ref->data;
   }
 }
 /*!
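
With the new input_dltensors_ index, SetInputZeroCopy only touches the DLTensor arguments registered for the given entry id, instead of probing every op's argument map on every call. A self-contained sketch of that pointer-patching scheme; DLTensorLike is a simplified stand-in for dlpack's DLTensor, modeling only the data field the runtime actually patches:

#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-in for dlpack's DLTensor.
struct DLTensorLike {
  void* data = nullptr;  // payload pointer; shape/dtype omitted for brevity
};

class ZeroCopyInputs {
 public:
  explicit ZeroCopyInputs(size_t num_entries) : input_dltensors_(num_entries) {}

  // SetupOpExecs-time: remember every op argument that aliases entry eid.
  void Register(size_t eid, DLTensorLike* arg) {
    input_dltensors_[eid].push_back(arg);
  }

  // SetInputZeroCopy-time: retarget all aliases at the user's buffer. Cost is
  // proportional to the number of consumers of this input, not to the number
  // of ops in the graph.
  void SetInputZeroCopy(size_t eid, void* user_data) {
    for (DLTensorLike* t : input_dltensors_[eid]) {
      t->data = user_data;
    }
  }

 private:
  std::vector<std::vector<DLTensorLike*>> input_dltensors_;
};

int main() {
  DLTensorLike conv_arg, add_arg;
  ZeroCopyInputs inputs(/*num_entries=*/4);
  inputs.Register(/*eid=*/0, &conv_arg);  // two ops consume input entry 0
  inputs.Register(/*eid=*/0, &add_arg);

  int user_buffer[16] = {0};
  inputs.SetInputZeroCopy(0, user_buffer);
  assert(conv_arg.data == user_buffer && add_arg.data == user_buffer);
  return 0;
}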
@@ -324,34 +320,38 @@ void GraphRuntime::SetupStorage() {
 
 void GraphRuntime::SetupOpExecs() {
   op_execs_.resize(this->GetNumOfNodes());
-  op_args_.resize(this->GetNumOfNodes());
+  input_dltensors_.resize(num_node_entries());
+  std::unordered_set<uint32_t> input_node_eids;
+  for (size_t i = 0; i < input_nodes_.size(); i++) {
+    uint32_t nid = input_nodes_[i];
+    input_node_eids.insert(entry_id(nid, 0));
+  }
+
   // setup the array and requirements.
   for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
     const auto& inode = nodes_[nid];
     if (inode.op_type == "null") continue;
     std::vector<DLTensor> args;
-    std::vector<uint32_t> input_entry_ids;
     for (const auto& e : inode.inputs) {
       uint32_t eid = this->entry_id(e);
       args.push_back(*(data_entry_[eid].operator->()));
-      input_entry_ids.push_back(eid);
     }
     for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
       uint32_t eid = this->entry_id(nid, index);
       args.push_back(*(data_entry_[eid].operator->()));
     }
     CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
 
-    std::tie(op_execs_[nid], op_args_[nid]) =
+    std::shared_ptr<OpArgs> op_args = nullptr;
+    std::tie(op_execs_[nid], op_args) =
        CreateTVMOp(inode.param, args, inode.inputs.size());
-    auto& entry_to_input_pos = op_args_[nid]->input_entry_ids;
-    for (uint32_t i = 0; i < input_entry_ids.size(); ++i) {
-      const auto eid = input_entry_ids[i];
-      auto it = entry_to_input_pos.find(eid);
-      if (it == entry_to_input_pos.end()) {
-        entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i});
-      } else {
-        it->second.push_back(i);
+
+    for (size_t i = 0; i < inode.inputs.size(); i++) {
+      uint32_t eid = this->entry_id(inode.inputs[i]);
+      // check if op input is model input
+      if (input_node_eids.count(eid) > 0) {
+        input_dltensors_[eid].push_back(
+            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
       }
     }
   }
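
The SetupOpExecs hunk above is where registration happens: the entry ids of the graph's input ("null") nodes are collected into a hash set, and each op argument whose entry id is in that set has its DLTensor* recorded in input_dltensors_. A hedged sketch of just that filtering step, under the simplifying assumption of one output per input node (so an entry id equals its node id, where the real runtime maps (node, output) pairs through node_row_ptr_); NodeEntry, EntryId, and MarkModelInputs are illustrative names:

#include <cstdint>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for a flattened (node, output) input reference.
struct NodeEntry {
  uint32_t node_id;
  uint32_t index;
};

// Fake entry-id scheme: one output per node, so eid == node id.
uint32_t EntryId(const NodeEntry& e) { return e.node_id; }

// Returns, for each op input, whether it reads directly from a model input
// and therefore must be patchable by SetInputZeroCopy.
std::vector<bool> MarkModelInputs(const std::vector<uint32_t>& input_nodes,
                                  const std::vector<NodeEntry>& op_inputs) {
  std::unordered_set<uint32_t> input_node_eids;
  for (uint32_t nid : input_nodes) {
    input_node_eids.insert(EntryId({nid, 0}));
  }
  std::vector<bool> marks;
  marks.reserve(op_inputs.size());
  for (const NodeEntry& e : op_inputs) {
    marks.push_back(input_node_eids.count(EntryId(e)) > 0);
  }
  return marks;
}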
7 changes: 4 additions & 3 deletions src/runtime/graph/graph_runtime.h
@@ -70,7 +70,6 @@ struct TVMOpParam {
 class GraphRuntime : public ModuleNode {
   struct OpArgs {
     std::vector<DLTensor> args;
-    std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
     std::vector<TVMValue> arg_values;
     std::vector<int> arg_tcodes;
     std::vector<int64_t> shape_data;
@@ -399,6 +398,10 @@ class GraphRuntime : public ModuleNode {
   std::vector<Node> nodes_;
   /*! \brief The argument nodes. */
   std::vector<uint32_t> input_nodes_;
+  /*! \brief Map of input names to input indices. */
+  std::unordered_map<std::string, uint32_t> input_map_;
+  /*! \brief Used for quick node input DLTensor* lookup given an input eid. */
+  std::vector<std::vector<DLTensor*>> input_dltensors_;
   /*! \brief Used for quick entry indexing. */
   std::vector<uint32_t> node_row_ptr_;
   /*! \brief Output entries. */
@@ -417,8 +420,6 @@ class GraphRuntime : public ModuleNode {
   std::vector<size_t> data_alignment_;
   /*! \brief Operator on each node. */
   std::vector<std::function<void()> > op_execs_;
-  /*! \brief Arg info of TVM ops */
-  std::vector<std::shared_ptr<OpArgs> > op_args_;
 };
 
 std::vector<TVMContext> GetAllContext(const TVMArgs& args);
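
End to end, these changes make input-name resolution O(1) and make zero-copy rebinding proportional to the number of consumers of that input rather than to the number of ops in the graph. A hedged usage sketch of the client side, assuming the module exposes the set_input_zero_copy packed function (registered in this file's GetFunction, outside these hunks) and an input named "data"; the caller must keep user_tensor alive, correctly shaped, and correctly typed while the runtime uses it:

#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

// Binds user-owned memory as the graph input without a copy. Later writes to
// user_tensor->data are observed by subsequent Run() calls with no further
// set_input call.
void BindInputWithoutCopy(tvm::runtime::Module graph_runtime_mod,
                          DLTensor* user_tensor) {
  tvm::runtime::PackedFunc set_input_zero_copy =
      graph_runtime_mod.GetFunction("set_input_zero_copy");
  // Name resolution goes through the new input_map_, so this lookup is O(1).
  set_input_zero_copy("data", user_tensor);
}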