2626//! See the tests and examples repository for more examples.
2727
2828use std:: convert:: { TryFrom , TryInto } ;
29+ use std:: sync:: Arc ;
2930use std:: {
3031 ffi:: CString ,
3132 os:: raw:: { c_char, c_int} ,
@@ -39,36 +40,43 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
3940
4041pub type Result < T > = std:: result:: Result < T , Error > ;
4142
42- /// Wrapper around TVM function handle which includes `is_global`
43- /// indicating whether the function is global or not, and `is_cloned` showing
44- /// not to drop a cloned function from Rust side.
45- /// The value of these fields can be accessed through their respective methods.
4643#[ derive( Debug , Hash ) ]
47- pub struct Function {
48- pub ( crate ) handle : ffi:: TVMFunctionHandle ,
49- // whether the registered function is global or not.
50- is_global : bool ,
51- from_rust : bool ,
44+ struct FunctionPtr {
45+ handle : ffi:: TVMFunctionHandle ,
5246}
5347
54- unsafe impl Send for Function { }
55- unsafe impl Sync for Function { }
48+ // NB(@jroesch): I think this is ok, need to double check,
49+ // if not we should mutex the pointer or move to Rc.
50+ unsafe impl Send for FunctionPtr { }
51+ unsafe impl Sync for FunctionPtr { }
52+
53+ impl FunctionPtr {
54+ fn from_raw ( handle : ffi:: TVMFunctionHandle ) -> Self {
55+ FunctionPtr { handle }
56+ }
57+ }
58+
59+ impl Drop for FunctionPtr {
60+ fn drop ( & mut self ) {
61+ check_call ! ( ffi:: TVMFuncFree ( self . handle) ) ;
62+ }
63+ }
64+
65+ /// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust.
66+ #[ derive( Debug , Hash ) ]
67+ pub struct Function {
68+ inner : Arc < FunctionPtr > ,
69+ }
5670
5771impl Function {
58- pub ( crate ) fn new ( handle : ffi:: TVMFunctionHandle ) -> Self {
72+ pub ( crate ) fn from_raw ( handle : ffi:: TVMFunctionHandle ) -> Self {
5973 Function {
60- handle,
61- is_global : false ,
62- from_rust : false ,
74+ inner : Arc :: new ( FunctionPtr :: from_raw ( handle) ) ,
6375 }
6476 }
6577
6678 pub unsafe fn null ( ) -> Self {
67- Function {
68- handle : std:: ptr:: null_mut ( ) ,
69- is_global : false ,
70- from_rust : false ,
71- }
79+ Function :: from_raw ( std:: ptr:: null_mut ( ) )
7280 }
7381
7482 /// For a given function, it returns a function by name.
@@ -84,11 +92,7 @@ impl Function {
8492 if handle. is_null ( ) {
8593 None
8694 } else {
87- Some ( Function {
88- handle,
89- is_global : true ,
90- from_rust : false ,
91- } )
95+ Some ( Function :: from_raw ( handle) )
9296 }
9397 }
9498
@@ -103,12 +107,7 @@ impl Function {
103107
104108 /// Returns the underlying TVM function handle.
105109 pub fn handle ( & self ) -> ffi:: TVMFunctionHandle {
106- self . handle
107- }
108-
109- /// Returns `true` if the underlying TVM function is global and `false` otherwise.
110- pub fn is_global ( & self ) -> bool {
111- self . is_global
110+ self . inner . handle
112111 }
113112
114113 /// Calls the function that created from `Builder`.
@@ -122,7 +121,7 @@ impl Function {
122121
123122 let ret_code = unsafe {
124123 ffi:: TVMFuncCall (
125- self . handle ,
124+ self . handle ( ) ,
126125 values. as_mut_ptr ( ) as * mut ffi:: TVMValue ,
127126 type_codes. as_mut_ptr ( ) as * mut c_int ,
128127 num_args as c_int ,
@@ -171,25 +170,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,);
171170
172171impl Clone for Function {
173172 fn clone ( & self ) -> Function {
174- Self {
175- handle : self . handle ,
176- is_global : self . is_global ,
177- from_rust : true ,
173+ Function {
174+ inner : self . inner . clone ( ) ,
178175 }
179176 }
180177}
181178
182- // impl Drop for Function {
183- // fn drop(&mut self) {
184- // if !self.is_global && !self.is_cloned {
185- // check_call!(ffi::TVMFuncFree(self.handle));
186- // }
187- // }
188- // }
189-
190179impl From < Function > for RetValue {
191180 fn from ( func : Function ) -> RetValue {
192- RetValue :: FuncHandle ( func. handle )
181+ RetValue :: FuncHandle ( func. handle ( ) )
193182 }
194183}
195184
@@ -198,7 +187,7 @@ impl TryFrom<RetValue> for Function {
198187
199188 fn try_from ( ret_value : RetValue ) -> Result < Function > {
200189 match ret_value {
201- RetValue :: FuncHandle ( handle) => Ok ( Function :: new ( handle) ) ,
190+ RetValue :: FuncHandle ( handle) => Ok ( Function :: from_raw ( handle) ) ,
202191 _ => Err ( Error :: downcast (
203192 format ! ( "{:?}" , ret_value) ,
204193 "FunctionHandle" ,
@@ -209,10 +198,10 @@ impl TryFrom<RetValue> for Function {
209198
210199impl < ' a > From < Function > for ArgValue < ' a > {
211200 fn from ( func : Function ) -> ArgValue < ' a > {
212- if func. handle . is_null ( ) {
201+ if func. handle ( ) . is_null ( ) {
213202 ArgValue :: Null
214203 } else {
215- ArgValue :: FuncHandle ( func. handle )
204+ ArgValue :: FuncHandle ( func. handle ( ) )
216205 }
217206 }
218207}
@@ -222,7 +211,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
222211
223212 fn try_from ( arg_value : ArgValue < ' a > ) -> Result < Function > {
224213 match arg_value {
225- ArgValue :: FuncHandle ( handle) => Ok ( Function :: new ( handle) ) ,
214+ ArgValue :: FuncHandle ( handle) => Ok ( Function :: from_raw ( handle) ) ,
226215 _ => Err ( Error :: downcast (
227216 format ! ( "{:?}" , arg_value) ,
228217 "FunctionHandle" ,
@@ -236,7 +225,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
236225
237226 fn try_from ( arg_value : & ArgValue < ' a > ) -> Result < Function > {
238227 match arg_value {
239- ArgValue :: FuncHandle ( handle) => Ok ( Function :: new ( * handle) ) ,
228+ ArgValue :: FuncHandle ( handle) => Ok ( Function :: from_raw ( * handle) ) ,
240229 _ => Err ( Error :: downcast (
241230 format ! ( "{:?}" , arg_value) ,
242231 "FunctionHandle" ,
0 commit comments