 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
2424#include " graph_runtime.h"
2525
26+ #include < tvm/runtime/device_api.h>
2627#include < tvm/runtime/ndarray.h>
2728#include < tvm/runtime/packed_func.h>
2829#include < tvm/runtime/registry.h>
3839
3940namespace tvm {
4041namespace runtime {
42+ namespace details {
43+ inline size_t GetDataAlignment (const DLTensor& arr) {
44+ size_t align = (arr.dtype .bits / 8 ) * arr.dtype .lanes ;
45+ if (align < kAllocAlignment ) return kAllocAlignment ;
46+ return align;
47+ }
48+ } // namespace details
4149
4250/* !
4351 * \brief Run all the operations one by one.
@@ -123,6 +131,39 @@ std::string GraphRuntime::GetInputName(int index) const {
123131std::vector<std::string> GraphRuntime::GetWeightNames () const {
124132 return weight_names_;
125133}
134+ /* !
135+ * \brief set index-th input to the graph without copying the data.
136+ * \param index The input index.
137+ * \param data_ref The input data that is referred.
138+ */
139+ void GraphRuntime::SetInputZeroCopy (int index, DLTensor* data_ref) {
140+ CHECK_LT (static_cast <size_t >(index), input_nodes_.size ());
141+ uint32_t eid = this ->entry_id (input_nodes_[index], 0 );
142+ const DLTensor* old_t = data_entry_[eid].operator ->();
143+
144+ // check the consistency of input
145+ CHECK_EQ (data_alignment_[eid], details::GetDataAlignment (*data_ref));
146+ CHECK_EQ (reinterpret_cast <size_t >(data_ref->data ) % kAllocAlignment , 0 );
147+ CHECK_EQ (old_t ->ndim , static_cast <size_t >(data_ref->ndim ));
148+ CHECK_EQ (old_t ->ctx .device_type , data_ref->ctx .device_type );
149+ CHECK_EQ (old_t ->ctx .device_id , data_ref->ctx .device_id );
150+ for (auto i = 0 ; i < data_ref->ndim ; ++i) {
151+ CHECK_EQ (old_t ->shape [i], data_ref->shape [i]);
152+ }
153+
154+ // Update the data pointer for each argument of each op
155+ for (auto & op_arg : op_args_) {
156+ if (op_arg) {
157+ const auto it = op_arg->input_entry_ids .find (eid);
158+ if (it != op_arg->input_entry_ids .end ()) {
159+ for (const auto i : it->second ) {
160+ DLTensor* t = static_cast <DLTensor*>(op_arg->arg_values [i].v_handle );
161+ t->data = data_ref->data ;
162+ }
163+ }
164+ }
165+ }
166+ }
126167/* !
127168 * \brief Get the number of outputs
128169 *
@@ -210,7 +251,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
210251 }
211252}
212253
213- void GraphRuntime::ShareParams (const GraphRuntime& other, dmlc::Stream* strm) {
254+ void GraphRuntime::ShareParams (const GraphRuntime& other, dmlc::Stream* strm) {
214255 uint64_t header, reserved;
215256 CHECK (strm->Read (&header))
216257 << " Invalid parameters file format" ;
@@ -232,6 +273,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
232273 CHECK_EQ (data_entry_[eid].use_count (), 1 );
233274 data_entry_[eid] = other.GetInput (GetInputIndex (names[i]));
234275 CHECK_GT (data_entry_[eid].use_count (), 1 );
276+ const DLTensor* tmp = data_entry_[eid].operator ->();
277+ data_alignment_[eid] = details::GetDataAlignment (*tmp);
235278 }
236279 this ->SetupOpExecs ();
237280}
@@ -294,30 +337,49 @@ void GraphRuntime::SetupStorage() {
294337 // memory assignment for each node entry. The allocated memory on each device
295338 // is mapped to this pool.
296339 data_entry_.resize (num_node_entries ());
340+ data_alignment_.resize (num_node_entries ());
297341 for (size_t i = 0 ; i < data_entry_.size (); ++i) {
298342 int storage_id = attrs_.storage_id [i];
299343 CHECK_LT (static_cast <size_t >(storage_id), storage_pool_.size ());
300344 data_entry_[i] =
301345 storage_pool_[storage_id].CreateView (attrs_.shape [i], vtype[i]);
346+ const DLTensor* tmp = data_entry_[i].operator ->();
347+ data_alignment_[i] = details::GetDataAlignment (*tmp);
302348 }
303349}
304350
305351void GraphRuntime::SetupOpExecs () {
306352 op_execs_.resize (this ->GetNumOfNodes ());
353+ op_args_.resize (this ->GetNumOfNodes ());
307354 // setup the array and requirements.
308355 for (uint32_t nid = 0 ; nid < this ->GetNumOfNodes (); ++nid) {
309356 const auto & inode = nodes_[nid];
310357 if (inode.op_type == " null" ) continue ;
311358 std::vector<DLTensor> args;
359+ std::vector<uint32_t > input_entry_ids;
312360 for (const auto & e : inode.inputs ) {
313- args.push_back (*(data_entry_[this ->entry_id (e)].operator ->()));
361+ uint32_t eid = this ->entry_id (e);
362+ args.push_back (*(data_entry_[eid].operator ->()));
363+ input_entry_ids.push_back (eid);
314364 }
315365 for (uint32_t index = 0 ; index < inode.param .num_outputs ; ++index) {
316366 uint32_t eid = this ->entry_id (nid, index);
317367 args.push_back (*(data_entry_[eid].operator ->()));
318368 }
369+
319370 if (inode.op_type == " tvm_op" ) {
320- op_execs_[nid] = CreateTVMOp (inode.param , args, inode.inputs .size ());
371+ std::tie (op_execs_[nid], op_args_[nid]) =
372+ CreateTVMOp (inode.param , args, inode.inputs .size ());
373+ auto & entry_to_input_pos = op_args_[nid]->input_entry_ids ;
374+ for (uint32_t i = 0 ; i < input_entry_ids.size (); ++i) {
375+ const auto eid = input_entry_ids[i];
376+ auto it = entry_to_input_pos.find (eid);
377+ if (it == entry_to_input_pos.end ()) {
378+ entry_to_input_pos.emplace (eid, std::vector<uint32_t >{i});
379+ } else {
380+ it->second .push_back (i);
381+ }
382+ }
321383 } else if (inode.op_type == " _tensorrt_subgraph_op" ) {
322384#ifdef TVM_GRAPH_RUNTIME_TENSORRT
323385 CHECK_EQ (inode.subgraphs .size (), 1U ) << " Only supports one subgraph per node" ;
@@ -333,25 +395,19 @@ void GraphRuntime::SetupOpExecs() {
333395 }
334396}
335397
336- std::function<void ()> GraphRuntime::CreateTVMOp (
398+ std::pair<std:: function<void ()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRuntime::CreateTVMOp (
337399 const TVMOpParam& param,
338400 const std::vector<DLTensor>& args,
339401 size_t num_inputs) {
340- struct OpArgs {
341- std::vector<DLTensor> args;
342- std::vector<TVMValue> arg_values;
343- std::vector<int > arg_tcodes;
344- std::vector<int64_t > shape_data;
345- };
346- std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>();
402+ std::shared_ptr<GraphRuntime::OpArgs> arg_ptr = std::make_shared<GraphRuntime::OpArgs>();
347403 // setup address.
348- arg_ptr->args = std::move ( args) ;
404+ arg_ptr->args = args;
349405 if (param.flatten_data ) {
350406 arg_ptr->shape_data .resize (arg_ptr->args .size ());
351407 }
352408 for (size_t i = 0 ; i < arg_ptr->args .size (); ++i) {
353409 TVMValue v;
354- DLTensor* t = &( arg_ptr->args [i]) ;
410+ DLTensor* t = &arg_ptr->args [i];
355411 v.v_handle = t;
356412 arg_ptr->arg_values .push_back (v);
357413 arg_ptr->arg_tcodes .push_back (kArrayHandle );
@@ -364,7 +420,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
364420 }
365421
366422 if (param.func_name == " __nop" ) {
367- return [](){};
423+ return { [](){}, arg_ptr };
368424 } else if (param.func_name == " __copy" ) {
369425 // Perform cross device data copy.
370426 // Directly copy data from the input to the output.
@@ -373,7 +429,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
373429 DLTensor* to = static_cast <DLTensor*>(arg_ptr->arg_values [1 ].v_handle );
374430 TVM_CCALL (TVMArrayCopyFromTo (from, to, nullptr ));
375431 };
376- return fexec;
432+ return { fexec, arg_ptr} ;
377433 }
378434 CHECK (!module_.IsEmpty ())
379435 << " Module cannot be empty in order to get functions from the lib" ;
@@ -390,7 +446,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
390446 static_cast <int >(arg_ptr->arg_values .size ()));
391447 pf.CallPacked (targs, &rv);
392448 };
393- return fexec;
449+ return { fexec, arg_ptr} ;
394450}
395451
396452PackedFunc GraphRuntime::GetFunction (
@@ -406,14 +462,23 @@ PackedFunc GraphRuntime::GetFunction(
406462 this ->SetInput (args[0 ], args[1 ]);
407463 }
408464 });
465+ } else if (name == " set_input_zero_copy" ) {
466+ return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
467+ if (args[0 ].type_code () == kStr ) {
468+ int in_idx = this ->GetInputIndex (args[0 ]);
469+ if (in_idx >= 0 ) this ->SetInputZeroCopy (in_idx, args[1 ]);
470+ } else {
471+ this ->SetInputZeroCopy (args[0 ], args[1 ]);
472+ }
473+ });
409474 } else if (name == " get_output" ) {
410475 return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
411- if (args.num_args == 2 ) {
412- this ->CopyOutputTo (args[0 ], args[1 ]);
413- } else {
414- *rv = this ->GetOutput (args[0 ]);
415- }
416- });
476+ if (args.num_args == 2 ) {
477+ this ->CopyOutputTo (args[0 ], args[1 ]);
478+ } else {
479+ *rv = this ->GetOutput (args[0 ]);
480+ }
481+ });
417482 } else if (name == " get_input" ) {
418483 return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
419484 int in_idx = 0 ;
0 commit comments