
Commit 641ce71

GraphExecutor: Fix wild pointer assignment when input and output are reshape (#17152)

* GraphExecutor: Fix wild pointer assign when input and output are reshape
* lint fix

Co-authored-by: Yuwei-EdgeCortix <[email protected]>
1 parent 32e9a48 commit 641ce71
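In brief: the graph executor lowers a standalone reshape to a __nop node whose output aliases its input's storage. Before this change, SetOutputZeroCopy on such a node only retargeted the DLTensors registered for the __nop's own entry, so the op that actually produces the value kept writing through the old internal pointer and the user's output buffer was never filled. A condensed sketch of the triggering pattern, modelled on the new test below (shapes and names are illustrative, and whether the final reshape lowers to a __nop depends on fusion):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# A graph whose output is a reshape; it can lower to a storage-aliasing __nop.
x = relay.var("x", shape=(8, 8), dtype="float32")
y = relay.reshape(relay.nn.matmul(x, x), (-1,))
mod = tvm.IRModule.from_expr(relay.Function([x], y))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
m = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))

out = tvm.nd.empty((64,), device=tvm.cpu(0))
m.set_input_zero_copy("x", tvm.nd.array(np.random.rand(8, 8).astype("float32")))
m.set_output_zero_copy(0, out)  # pre-fix: producer tensors were not retargeted
m.run()                         # post-fix: the matmul writes directly into out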

File tree: 3 files changed (+73, -0 lines)

src/runtime/graph_executor/graph_executor.cc

Lines changed: 22 additions & 0 deletions
@@ -230,6 +230,16 @@ void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) {
   // check the consistency of output
   CheckExternalDLTensor(data_ref, output_node_eid);
 
+  if (nodes_[output_node.node_id].op_type == "tvm_op" &&
+      nodes_[output_node.node_id].param.func_name == "__nop") {
+    const NodeEntry& input_node = nodes_[output_node.node_id].inputs[0];
+    output_node_eid = this->entry_id(input_node);
+    ICHECK_NE(node_output_dltensors_[output_node_eid].size(), 0);
+    for (DLTensor* t : node_output_dltensors_[output_node_eid]) {
+      t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
+    }
+  }
+
   // Update the data pointer for output op
   for (DLTensor* t : output_dltensors_[output_node_eid]) {
     t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
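In plain terms, the added branch says: if the node being retargeted is a reshape __nop, the value is really produced one entry upstream, so rewrite the data pointers recorded for the __nop's input eid instead. A toy Python model of the redirection (plain dicts standing in for TVM internals):

# eid -> tensor views recorded at setup time (see node_output_dltensors_ below).
node_output_dltensors = {7: [{"data": "internal_buffer"}]}

def set_output_zero_copy(node, user_buf):
    eid = node["eid"]
    if node.get("func_name") == "__nop":   # reshape lowered to a no-op
        eid = node["input_eid"]            # hop to the producing entry
        assert node_output_dltensors[eid]  # mirrors the ICHECK_NE above
        for t in node_output_dltensors[eid]:
            t["data"] = user_buf           # mirrors the DLTensor::data rewrite

nop = {"eid": 9, "func_name": "__nop", "input_eid": 7}
set_output_zero_copy(nop, "user_buffer")
assert node_output_dltensors[7][0]["data"] == "user_buffer"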
@@ -540,6 +550,13 @@ void GraphExecutor::SetupOpExecs() {
           input_dltensors_[input_eid].push_back(
               const_cast<DLTensor*>(data_entry_[eid].operator->()));
         }
+      } else {
+        const auto& arg_node = nodes_[inode.inputs[i].node_id];
+        if (arg_node.op_type == "tvm_op" && arg_node.param.func_name == "__nop") {
+          uint32_t arg_input_eid = this->entry_id(arg_node.inputs[0]);
+          input_dltensors_[arg_input_eid].push_back(
+              static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
+        }
       }
       // check if any model output is the input of the op
       if (output_node_eids.count(input_eid) > 0) {
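This is the input-side counterpart: when an op's argument is produced by a reshape __nop, the argument tensor is additionally filed under the __nop's input eid, so a later SetInputZeroCopy on that entry (a model input that merely passes through a reshape) also retargets this consumer. A toy sketch of the bookkeeping (illustrative structures, not TVM code):

input_dltensors = {}  # eid -> argument tensor views to retarget on zero-copy

def register_op_arg(arg_node, arg_tensor):
    if arg_node.get("func_name") == "__nop":
        eid = arg_node["input_eid"]  # the entry the __nop merely aliases
        input_dltensors.setdefault(eid, []).append(arg_tensor)

matmul_arg = {"data": "internal_buffer"}
register_op_arg({"func_name": "__nop", "input_eid": 2}, matmul_arg)
assert input_dltensors[2] == [matmul_arg]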
@@ -554,6 +571,11 @@ void GraphExecutor::SetupOpExecs() {
       if (output_node_eids.count(output_eid) > 0) {
         output_dltensors_[output_eid].push_back(
             static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
+      } else {
+        // If the node is not a model output, record its output tensors to
+        // support set_output_zero_copy on reshape __nop nodes.
+        node_output_dltensors_[output_eid].push_back(
+            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
       }
     }
   }
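And this is the recording half that feeds the SetOutputZeroCopy branch above: at setup time, op outputs that are not model outputs are remembered per eid, because a zero-copy request on a downstream __nop must be able to find and retarget them later. A matching toy sketch (not TVM code):

output_dltensors, node_output_dltensors = {}, {}

def record_output(eid, tensor, output_node_eids):
    if eid in output_node_eids:
        output_dltensors.setdefault(eid, []).append(tensor)       # model output
    else:
        node_output_dltensors.setdefault(eid, []).append(tensor)  # kept for __nop

record_output(7, {"data": "internal_buffer"}, output_node_eids={11})
assert 7 in node_output_dltensors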

src/runtime/graph_executor/graph_executor.h

Lines changed: 2 additions & 0 deletions
@@ -464,6 +464,8 @@ class TVM_DLL GraphExecutor : public ModuleNode {
   std::vector<std::vector<DLTensor*>> output_dltensors_;
   /*! \brief Used for quick node(both model output and op input) DLTensor* lookup given an eid. */
   std::vector<std::vector<DLTensor*>> both_output_opinput_dltensors_;
+  /*! \brief Used for quick node output DLTensor* lookup given a nop's input eid. */
+  std::unordered_map<int, std::vector<DLTensor*>> node_output_dltensors_;
   /*! \brief Used for quick entry_id lookup given an storage_id. */
   std::vector<std::vector<uint32_t>> sid_to_eid_;
   /*! \brief Used for quick entry indexing. */

tests/python/runtime/test_runtime_module_based_interface.py

Lines changed: 49 additions & 0 deletions
@@ -735,6 +735,54 @@ def test_graph_module_zero_copy():
     tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())
 
 
+@tvm.testing.requires_llvm
+def test_reshape_zero_copy():
+    shape0 = (56, 224)
+    shape1 = (112, 112)
+    in_name0 = "infeats0"
+    in_name1 = "infeats1"
+    x0 = relay.var(in_name0, shape=shape0, dtype="float32")
+    x0 = relay.reshape(x0, shape1)
+
+    x1 = relay.var(in_name1, shape=shape1, dtype="float32")
+    mat = relay.nn.matmul(x0, x1)
+    _y = relay.reshape(mat, (-1))
+    func = relay.Function(relay.analysis.free_vars(_y), _y)
+    mod = tvm.IRModule.from_expr(func)
+
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(mod, target="llvm")
+    m = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
+
+    data_ndarray0 = tvm.nd.array(
+        np.random.random(shape0).astype(np.float32), device=tvm.device("llvm", 0)
+    )
+    data_ndarray1 = tvm.nd.array(
+        np.random.random(shape1).astype(np.float32), device=tvm.device("llvm", 0)
+    )
+
+    def expected():
+        m.set_input(in_name0, data_ndarray0)
+        m.set_input(in_name1, data_ndarray1)
+        m.run()
+        return m.get_output(0).numpy()
+
+    def zero_copy():
+        from tvm.relay.frontend.common import infer_shape
+
+        outshape = infer_shape(_y)
+        output_view = tvm.nd.empty(outshape, device=tvm.device("llvm", 0))
+        m.set_input_zero_copy(in_name0, data_ndarray0)
+        m.set_input_zero_copy(in_name1, data_ndarray1)
+        m.set_output_zero_copy(0, output_view)
+        m.run()
+        return output_view.numpy()
+
+    golden_out = expected()
+    out = zero_copy()
+    tvm.testing.assert_allclose(golden_out, out)
+
+
 if __name__ == "__main__":
     test_legacy_compatibility()
     test_cpu()
@@ -747,3 +795,4 @@ def test_graph_module_zero_copy():
     test_cpu_get_graph_params_run()
     test_cpu_get_graph_params_compare()
     test_graph_module_zero_copy()
+    test_reshape_zero_copy()
