
Commit 6cf5739

Yinghai Lu authored and wweic committed
[Runtime] Enable set_input_zero_copy in GraphRuntime (apache#3416)
* Enable set_input_zero_copy in GraphRuntime
* Fix LoadParams
* Fix
* lint
* Fix remote context issue
* Fix
* Remove LOG
* Remove unused variables
* Add tests
* works
* More test scenarios
* make it simpler
* Remove unnecessary changes
* Address comments
* More comments
* Address comments
* Fix build
1 parent ea01cab commit 6cf5739
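
For orientation, here is a minimal sketch of how the new entry point is driven from C++, mirroring the updated test at the bottom of this diff. The graph `json` string, compiled `mod`, input name "a", and shape {2, 3} are illustrative assumptions, not fixed by the runtime:

#include <string>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

// Usage sketch only; `json` and `mod` are assumed to come from a prior build step.
void RunWithZeroCopy(const std::string& json, tvm::runtime::Module mod) {
  auto* pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
  tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)kDLCPU, 0);

  auto set_input_zero_copy = run_mod.GetFunction("set_input_zero_copy", false);
  auto run_f = run_mod.GetFunction("run", false);

  auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
  // The runtime keeps A's raw data pointer; A must stay alive across run_f().
  set_input_zero_copy("a", &A.ToDLPack()->dl_tensor);
  run_f();  // executors read A's buffer directly instead of a copied-in buffer
}

Unlike `set_input`, which copies the caller's array into the runtime's own storage, `set_input_zero_copy` re-points the op arguments at the caller's buffer, so that buffer must remain valid for every subsequent run.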

4 files changed: +143 -35 lines changed


src/runtime/graph/graph_runtime.cc

Lines changed: 89 additions & 24 deletions
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,6 +23,7 @@
  */
 #include "graph_runtime.h"
 
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
@@ -38,6 +39,13 @@
 
 namespace tvm {
 namespace runtime {
+namespace details {
+inline size_t GetDataAlignment(const DLTensor& arr) {
+  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
+  if (align < kAllocAlignment) return kAllocAlignment;
+  return align;
+}
+}  // namespace details
 
 /*!
  * \brief Run all the operations one by one.
@@ -123,6 +131,39 @@ std::string GraphRuntime::GetInputName(int index) const {
 std::vector<std::string> GraphRuntime::GetWeightNames() const {
   return weight_names_;
 }
+/*!
+ * \brief set index-th input to the graph without copying the data.
+ * \param index The input index.
+ * \param data_ref The input data that is referred.
+ */
+void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
+  CHECK_LT(static_cast<size_t>(index), input_nodes_.size());
+  uint32_t eid = this->entry_id(input_nodes_[index], 0);
+  const DLTensor* old_t = data_entry_[eid].operator->();
+
+  // check the consistency of input
+  CHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref));
+  CHECK_EQ(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment, 0);
+  CHECK_EQ(old_t->ndim, static_cast<size_t>(data_ref->ndim));
+  CHECK_EQ(old_t->ctx.device_type, data_ref->ctx.device_type);
+  CHECK_EQ(old_t->ctx.device_id, data_ref->ctx.device_id);
+  for (auto i = 0; i < data_ref->ndim; ++i) {
+    CHECK_EQ(old_t->shape[i], data_ref->shape[i]);
+  }
+
+  // Update the data pointer for each argument of each op
+  for (auto& op_arg : op_args_) {
+    if (op_arg) {
+      const auto it = op_arg->input_entry_ids.find(eid);
+      if (it != op_arg->input_entry_ids.end()) {
+        for (const auto i : it->second) {
+          DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle);
+          t->data = data_ref->data;
+        }
+      }
+    }
+  }
+}
 /*!
  * \brief Get the number of outputs
  *
@@ -210,7 +251,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
   }
 }
 
-void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
+void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
   uint64_t header, reserved;
   CHECK(strm->Read(&header))
       << "Invalid parameters file format";
@@ -232,6 +273,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
     CHECK_EQ(data_entry_[eid].use_count(), 1);
     data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
     CHECK_GT(data_entry_[eid].use_count(), 1);
+    const DLTensor* tmp = data_entry_[eid].operator->();
+    data_alignment_[eid] = details::GetDataAlignment(*tmp);
   }
   this->SetupOpExecs();
 }
@@ -294,30 +337,49 @@ void GraphRuntime::SetupStorage() {
   // memory assignment for each node entry. The allocated memory on each device
   // is mapped to this pool.
   data_entry_.resize(num_node_entries());
+  data_alignment_.resize(num_node_entries());
   for (size_t i = 0; i < data_entry_.size(); ++i) {
     int storage_id = attrs_.storage_id[i];
     CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
     data_entry_[i] =
         storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
+    const DLTensor* tmp = data_entry_[i].operator->();
+    data_alignment_[i] = details::GetDataAlignment(*tmp);
   }
 }
 
 void GraphRuntime::SetupOpExecs() {
   op_execs_.resize(this->GetNumOfNodes());
+  op_args_.resize(this->GetNumOfNodes());
   // setup the array and requirements.
   for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
     const auto& inode = nodes_[nid];
     if (inode.op_type == "null") continue;
     std::vector<DLTensor> args;
+    std::vector<uint32_t> input_entry_ids;
     for (const auto& e : inode.inputs) {
-      args.push_back(*(data_entry_[this->entry_id(e)].operator->()));
+      uint32_t eid = this->entry_id(e);
+      args.push_back(*(data_entry_[eid].operator->()));
+      input_entry_ids.push_back(eid);
     }
     for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
       uint32_t eid = this->entry_id(nid, index);
       args.push_back(*(data_entry_[eid].operator->()));
     }
+
     if (inode.op_type == "tvm_op") {
-      op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size());
+      std::tie(op_execs_[nid], op_args_[nid]) =
+          CreateTVMOp(inode.param, args, inode.inputs.size());
+      auto& entry_to_input_pos = op_args_[nid]->input_entry_ids;
+      for (uint32_t i = 0; i < input_entry_ids.size(); ++i) {
+        const auto eid = input_entry_ids[i];
+        auto it = entry_to_input_pos.find(eid);
+        if (it == entry_to_input_pos.end()) {
+          entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i});
+        } else {
+          it->second.push_back(i);
+        }
+      }
    } else if (inode.op_type == "_tensorrt_subgraph_op") {
 #ifdef TVM_GRAPH_RUNTIME_TENSORRT
       CHECK_EQ(inode.subgraphs.size(), 1U) << "Only supports one subgraph per node";
@@ -333,25 +395,19 @@ void GraphRuntime::SetupOpExecs() {
   }
 }
 
-std::function<void()> GraphRuntime::CreateTVMOp(
+std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRuntime::CreateTVMOp(
     const TVMOpParam& param,
     const std::vector<DLTensor>& args,
     size_t num_inputs) {
-  struct OpArgs {
-    std::vector<DLTensor> args;
-    std::vector<TVMValue> arg_values;
-    std::vector<int> arg_tcodes;
-    std::vector<int64_t> shape_data;
-  };
-  std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>();
+  std::shared_ptr<GraphRuntime::OpArgs> arg_ptr = std::make_shared<GraphRuntime::OpArgs>();
   // setup address.
-  arg_ptr->args = std::move(args);
+  arg_ptr->args = args;
   if (param.flatten_data) {
     arg_ptr->shape_data.resize(arg_ptr->args.size());
   }
   for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
     TVMValue v;
-    DLTensor* t = &(arg_ptr->args[i]);
+    DLTensor* t = &arg_ptr->args[i];
     v.v_handle = t;
     arg_ptr->arg_values.push_back(v);
     arg_ptr->arg_tcodes.push_back(kArrayHandle);
@@ -364,7 +420,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
   }
 
   if (param.func_name == "__nop") {
-    return [](){};
+    return {[](){}, arg_ptr};
   } else if (param.func_name == "__copy") {
     // Perform cross device data copy.
     // Directly copy data from the input to the output.
@@ -373,7 +429,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
       DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
       TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
     };
-    return fexec;
+    return {fexec, arg_ptr};
   }
   CHECK(!module_.IsEmpty())
       << "Module cannot be empty in order to get functions from the lib";
@@ -390,7 +446,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
                      static_cast<int>(arg_ptr->arg_values.size()));
     pf.CallPacked(targs, &rv);
   };
-  return fexec;
+  return {fexec, arg_ptr};
 }
 
 PackedFunc GraphRuntime::GetFunction(
@@ -406,14 +462,23 @@ PackedFunc GraphRuntime::GetFunction(
         this->SetInput(args[0], args[1]);
       }
     });
+  } else if (name == "set_input_zero_copy") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (args[0].type_code() == kStr) {
+        int in_idx = this->GetInputIndex(args[0]);
+        if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
+      } else {
+        this->SetInputZeroCopy(args[0], args[1]);
+      }
+    });
   } else if (name == "get_output") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        if (args.num_args == 2) {
-          this->CopyOutputTo(args[0], args[1]);
-        } else {
-          *rv = this->GetOutput(args[0]);
-        }
-      });
+      if (args.num_args == 2) {
+        this->CopyOutputTo(args[0], args[1]);
+      } else {
+        *rv = this->GetOutput(args[0]);
+      }
+    });
   } else if (name == "get_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       int in_idx = 0;
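
A note on the mechanism, since it is split across the hunks above: `SetupOpExecs` now records, for each op, a map from node-entry id to every position at which that entry appears in the op's packed argument list (`OpArgs::input_entry_ids`), and `SetInputZeroCopy` consults that map to patch only the affected `DLTensor::data` pointers. The cached executors are never rebuilt, so a rebind costs time proportional to the argument slots that alias the input. A standalone toy model of the lookup-and-patch pattern (the names here are illustrative, not the runtime's):

#include <cstdint>
#include <unordered_map>
#include <vector>

// Toy stand-in for OpArgs: entry id -> positions in this op's argument list.
struct ToyOpArgs {
  std::vector<void*> arg_data;  // models DLTensor::data of each packed argument
  std::unordered_map<uint32_t, std::vector<uint32_t>> input_entry_ids;
};

// Point every argument slot that aliases entry `eid` at `new_data`.
void ZeroCopyRebind(std::vector<ToyOpArgs>& ops, uint32_t eid, void* new_data) {
  for (auto& op : ops) {
    auto it = op.input_entry_ids.find(eid);
    if (it == op.input_entry_ids.end()) continue;  // op does not read this entry
    for (uint32_t pos : it->second) {
      op.arg_data[pos] = new_data;  // executors keep their TVMValue tables intact
    }
  }
}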

src/runtime/graph/graph_runtime.h

Lines changed: 25 additions & 5 deletions
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -34,6 +34,7 @@
 #include <tvm/runtime/packed_func.h>
 
 #include <memory>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include <string>
@@ -72,6 +73,14 @@ struct TVMOpParam {
  * TVM runtime PackedFunc API.
  */
 class GraphRuntime : public ModuleNode {
+  struct OpArgs {
+    std::vector<DLTensor> args;
+    std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
+    std::vector<TVMValue> arg_values;
+    std::vector<int> arg_tcodes;
+    std::vector<int64_t> shape_data;
+  };
+
  public:
  /*!
   * \brief Get member function to front-end
@@ -129,6 +138,12 @@ class GraphRuntime : public ModuleNode {
   * \return The name of the index-th input.
   */
  std::string GetInputName(int index) const;
+  /*!
+   * \brief set index-th input to the graph without copying the data
+   * \param index The input index.
+   * \param data_ref The input data that is referred.
+   */
+  void SetInputZeroCopy(int index, DLTensor* data_ref);
  /*!
   * \brief Get the number of outputs
   *
@@ -406,9 +421,9 @@ class GraphRuntime : public ModuleNode {
   * \param num_inputs Number of inputs.
   * \return The created executor.
   */
-  std::function<void()> CreateTVMOp(const TVMOpParam& attrs,
-                                    const std::vector<DLTensor>& args,
-                                    size_t num_inputs);
+  std::pair<std::function<void()>, std::shared_ptr<OpArgs> > CreateTVMOp(
+      const TVMOpParam& attrs, const std::vector<DLTensor>& args,
+      size_t num_inputs);
  // Get node entry index.
  uint32_t entry_id(uint32_t nid, uint32_t index) const {
    return node_row_ptr_[nid] + index;
@@ -441,11 +456,16 @@ class GraphRuntime : public ModuleNode {
  std::vector<NDArray> storage_pool_;
  /*! \brief Data entry of each node. */
  std::vector<NDArray> data_entry_;
+  /*! \brief Data alignment of each node. */
+  std::vector<size_t> data_alignment_;
  /*! \brief Operator on each node. */
  std::vector<std::function<void()> > op_execs_;
 #ifdef TVM_GRAPH_RUNTIME_TENSORRT
  contrib::TensorRTExecManager tensorrt_exec_manager_;
 #endif  // TVM_GRAPH_RUNTIME_TENSORRT
+
+  /*! \brief Arg info of TVM ops */
+  std::vector<std::shared_ptr<OpArgs> > op_args_;
 };
 
 std::vector<TVMContext> GetAllContext(const TVMArgs& args);
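
The new `data_alignment_` cache exists so `SetInputZeroCopy` can compare a replacement tensor against the alignment of the entry it displaces without re-deriving it from the storage pool. A worked example of `details::GetDataAlignment`; the value 64 used for `kAllocAlignment` below is an assumption for illustration, not TVM's definition (see `device_api.h` for the real constant):

#include <cstdio>
#include <dlpack/dlpack.h>

constexpr size_t kAllocAlignmentAssumed = 64;  // assumption, stands in for kAllocAlignment

// Same logic as details::GetDataAlignment in the graph_runtime.cc diff above.
inline size_t GetDataAlignment(const DLTensor& arr) {
  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
  if (align < kAllocAlignmentAssumed) return kAllocAlignmentAssumed;
  return align;
}

int main() {
  DLTensor f32{}, f32x32{};
  f32.dtype = {kDLFloat, 32, 1};      // float32: 4 bytes per element
  f32x32.dtype = {kDLFloat, 32, 32};  // vectorized float32x32: 128 bytes
  std::printf("%zu %zu\n", GetDataAlignment(f32), GetDataAlignment(f32x32));
  // prints "64 128": small dtypes are clamped up to the allocator alignment
}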

src/runtime/ndarray.cc

Lines changed: 2 additions & 2 deletions (whitespace-only changes to the license header)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY

tests/cpp/relay_build_module_test.cc

Lines changed: 27 additions & 4 deletions
@@ -85,18 +85,41 @@ TEST(Relay, BuildModule) {
   auto ctx = A->ctx;
   auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
   tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id);
-  auto set_input_f = run_mod.GetFunction("set_input", false);
+  auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false);
   auto run_f = run_mod.GetFunction("run", false);
   auto get_output_f = run_mod.GetFunction("get_output", false);
-  set_input_f("a", A);
-  set_input_f("b", B);
-  set_input_f("c", C);
+  set_input_f("a", &A.ToDLPack()->dl_tensor);
+  set_input_f("b", &B.ToDLPack()->dl_tensor);
+  set_input_f("c", &C.ToDLPack()->dl_tensor);
   run_f();
   tvm::runtime::NDArray Y = get_output_f(0);
   auto pY = (float*)Y.ToDLPack()->dl_tensor.data;
   for (int i = 0; i < 6; ++i) {
     CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
   }
+  // mutate the input a bit and run it again
+  for (int i = 0; i < 6; ++i) {
+    pB[i] = i + 3;
+  }
+  run_f();
+  tvm::runtime::NDArray Y2 = get_output_f(0);
+  auto pY2 = (float*)Y2.ToDLPack()->dl_tensor.data;
+  for (int i = 0; i < 6; ++i) {
+    CHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4);
+  }
+  // attach a different input and run it again
+  auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto pC2 = (float*)C2.ToDLPack()->dl_tensor.data;
+  for (int i = 0; i < 6; ++i) {
+    pC2[i] = i + 4;
+  }
+  set_input_f("c", &C2.ToDLPack()->dl_tensor);
+  run_f();
+  tvm::runtime::NDArray Y3 = get_output_f(0);
+  auto pY3 = (float*)Y3.ToDLPack()->dl_tensor.data;
+  for (int i = 0; i < 6; ++i) {
+    CHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4);
+  }
 }
 
 int main(int argc, char ** argv) {
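
The expected values in the three assertion blocks follow from the element-wise sum the test computes (a + b + c, judging by the checks): with pA[i] = i, pB[i] = i + 1, pC[i] = i + 2 initially, Y[i] = i + (i + 1) + (i + 2) = 3i + 3. Writing i + 3 through pB changes the result of the next run with no further set_input call, because the runtime now aliases B's buffer, giving Y2[i] = i + (i + 3) + (i + 2). Rebinding "c" to C2 then yields Y3[i] = i + (i + 3) + (i + 4). The middle block passing without touching the runtime is exactly the zero-copy behavior this commit adds.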
