Skip to content

Commit 09b989d

Browse files
authored
[Rust][Fix] Memory leak (#8714)
* Fix obvious memory leak in function.rs * Update object pointer
1 parent f5315ca commit 09b989d

File tree

7 files changed

+52
-69
lines changed

7 files changed

+52
-69
lines changed

rust/tvm-rt/src/function.rs

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
//! See the tests and examples repository for more examples.
2727
2828
use std::convert::{TryFrom, TryInto};
29+
use std::sync::Arc;
2930
use std::{
3031
ffi::CString,
3132
os::raw::{c_char, c_int},
@@ -39,36 +40,43 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
3940

4041
pub type Result<T> = std::result::Result<T, Error>;
4142

42-
/// Wrapper around TVM function handle which includes `is_global`
43-
/// indicating whether the function is global or not, and `is_cloned` showing
44-
/// not to drop a cloned function from Rust side.
45-
/// The value of these fields can be accessed through their respective methods.
4643
#[derive(Debug, Hash)]
47-
pub struct Function {
48-
pub(crate) handle: ffi::TVMFunctionHandle,
49-
// whether the registered function is global or not.
50-
is_global: bool,
51-
from_rust: bool,
44+
struct FunctionPtr {
45+
handle: ffi::TVMFunctionHandle,
5246
}
5347

54-
unsafe impl Send for Function {}
55-
unsafe impl Sync for Function {}
48+
// NB(@jroesch): I think this is ok, need to double check,
49+
// if not we should mutex the pointer or move to Rc.
50+
unsafe impl Send for FunctionPtr {}
51+
unsafe impl Sync for FunctionPtr {}
52+
53+
impl FunctionPtr {
54+
fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
55+
FunctionPtr { handle }
56+
}
57+
}
58+
59+
impl Drop for FunctionPtr {
60+
fn drop(&mut self) {
61+
check_call!(ffi::TVMFuncFree(self.handle));
62+
}
63+
}
64+
65+
/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust.
66+
#[derive(Debug, Hash)]
67+
pub struct Function {
68+
inner: Arc<FunctionPtr>,
69+
}
5670

5771
impl Function {
58-
pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
72+
pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
5973
Function {
60-
handle,
61-
is_global: false,
62-
from_rust: false,
74+
inner: Arc::new(FunctionPtr::from_raw(handle)),
6375
}
6476
}
6577

6678
pub unsafe fn null() -> Self {
67-
Function {
68-
handle: std::ptr::null_mut(),
69-
is_global: false,
70-
from_rust: false,
71-
}
79+
Function::from_raw(std::ptr::null_mut())
7280
}
7381

7482
/// For a given function, it returns a function by name.
@@ -84,11 +92,7 @@ impl Function {
8492
if handle.is_null() {
8593
None
8694
} else {
87-
Some(Function {
88-
handle,
89-
is_global: true,
90-
from_rust: false,
91-
})
95+
Some(Function::from_raw(handle))
9296
}
9397
}
9498

@@ -103,12 +107,7 @@ impl Function {
103107

104108
/// Returns the underlying TVM function handle.
105109
pub fn handle(&self) -> ffi::TVMFunctionHandle {
106-
self.handle
107-
}
108-
109-
/// Returns `true` if the underlying TVM function is global and `false` otherwise.
110-
pub fn is_global(&self) -> bool {
111-
self.is_global
110+
self.inner.handle
112111
}
113112

114113
/// Calls the function that created from `Builder`.
@@ -122,7 +121,7 @@ impl Function {
122121

123122
let ret_code = unsafe {
124123
ffi::TVMFuncCall(
125-
self.handle,
124+
self.handle(),
126125
values.as_mut_ptr() as *mut ffi::TVMValue,
127126
type_codes.as_mut_ptr() as *mut c_int,
128127
num_args as c_int,
@@ -171,25 +170,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,);
171170

172171
impl Clone for Function {
173172
fn clone(&self) -> Function {
174-
Self {
175-
handle: self.handle,
176-
is_global: self.is_global,
177-
from_rust: true,
173+
Function {
174+
inner: self.inner.clone(),
178175
}
179176
}
180177
}
181178

182-
// impl Drop for Function {
183-
// fn drop(&mut self) {
184-
// if !self.is_global && !self.is_cloned {
185-
// check_call!(ffi::TVMFuncFree(self.handle));
186-
// }
187-
// }
188-
// }
189-
190179
impl From<Function> for RetValue {
191180
fn from(func: Function) -> RetValue {
192-
RetValue::FuncHandle(func.handle)
181+
RetValue::FuncHandle(func.handle())
193182
}
194183
}
195184

@@ -198,7 +187,7 @@ impl TryFrom<RetValue> for Function {
198187

199188
fn try_from(ret_value: RetValue) -> Result<Function> {
200189
match ret_value {
201-
RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
190+
RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
202191
_ => Err(Error::downcast(
203192
format!("{:?}", ret_value),
204193
"FunctionHandle",
@@ -209,10 +198,10 @@ impl TryFrom<RetValue> for Function {
209198

210199
impl<'a> From<Function> for ArgValue<'a> {
211200
fn from(func: Function) -> ArgValue<'a> {
212-
if func.handle.is_null() {
201+
if func.handle().is_null() {
213202
ArgValue::Null
214203
} else {
215-
ArgValue::FuncHandle(func.handle)
204+
ArgValue::FuncHandle(func.handle())
216205
}
217206
}
218207
}
@@ -222,7 +211,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
222211

223212
fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
224213
match arg_value {
225-
ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
214+
ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
226215
_ => Err(Error::downcast(
227216
format!("{:?}", arg_value),
228217
"FunctionHandle",
@@ -236,7 +225,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
236225

237226
fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
238227
match arg_value {
239-
ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
228+
ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)),
240229
_ => Err(Error::downcast(
241230
format!("{:?}", arg_value),
242231
"FunctionHandle",

rust/tvm-rt/src/module.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl Module {
8282
return Err(errors::Error::NullHandle(name.into_string()?.to_string()));
8383
}
8484

85-
Ok(Function::new(fhandle))
85+
Ok(Function::from_raw(fhandle))
8686
}
8787

8888
/// Imports a dependent module such as `.ptx` for cuda gpu.

rust/tvm-rt/src/ndarray.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ use num_traits::Num;
6161

6262
use crate::errors::NDArrayError;
6363

64-
use crate::object::{Object, ObjectPtr};
64+
use crate::object::{Object, ObjectPtr, ObjectRef};
6565

6666
/// See the [`module-level documentation`](../ndarray/index.html) for more details.
6767
#[repr(C)]
@@ -73,7 +73,7 @@ pub struct NDArrayContainer {
7373
// Container Base
7474
dl_tensor: DLTensor,
7575
manager_ctx: *mut c_void,
76-
// TOOD: shape?
76+
shape: ObjectRef,
7777
}
7878

7979
impl NDArrayContainer {

rust/tvm-rt/src/object/object_ptr.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,18 +148,6 @@ impl Object {
148148
}
149149
}
150150

151-
// impl fmt::Debug for Object {
152-
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153-
// let index =
154-
// format!("{} // key: {}", self.type_index, "the_key");
155-
156-
// f.debug_struct("Object")
157-
// .field("type_index", &index)
158-
// // TODO(@jroesch: do we expose other fields?)
159-
// .finish()
160-
// }
161-
// }
162-
163151
/// An unsafe trait which should be implemented for an object
164152
/// subtype.
165153
///

rust/tvm-rt/src/to_function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ pub trait ToFunction<I, O>: Sized {
7474
&mut fhandle as *mut ffi::TVMFunctionHandle,
7575
));
7676

77-
Function::new(fhandle)
77+
Function::from_raw(fhandle)
7878
}
7979

8080
/// The callback function which is wrapped converted by TVM

rust/tvm-sys/build.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020
extern crate bindgen;
2121

22-
use std::{path::{Path, PathBuf}, str::FromStr};
22+
use std::{
23+
path::{Path, PathBuf},
24+
str::FromStr,
25+
};
2326

2427
use anyhow::{Context, Result};
2528
use tvm_build::{BuildConfig, CMakeSetting};
@@ -195,7 +198,10 @@ fn find_using_tvm_build() -> Result<TVMInstall> {
195198
if cfg!(feature = "use-vitis-ai") {
196199
build_config.settings.use_vitis_ai = Some(true);
197200
}
198-
if cfg!(any(feature = "static-linking", feature = "build-static-runtime")) {
201+
if cfg!(any(
202+
feature = "static-linking",
203+
feature = "build-static-runtime"
204+
)) {
199205
build_config.settings.build_static_runtime = Some(true);
200206
}
201207

rust/tvm/tests/basics/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ fn main() {
3535
let mut arr = NDArray::empty(shape, dev, dtype);
3636
arr.copy_from_buffer(data.as_mut_slice());
3737
let ret = NDArray::empty(shape, dev, dtype);
38-
let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
38+
let fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
3939
if !fadd.enabled(dev_name) {
4040
return;
4141
}

0 commit comments

Comments
 (0)