Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 114 additions & 85 deletions ffi/include/tvm/ffi/c_api.h

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions ffi/python/tvm_ffi/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ def convert(value: Any) -> Any:
return None
elif hasattr(value, "__dlpack__"):
return core.from_dlpack(
value,
required_alignment=core.__dlpack_auto_import_required_alignment__,
value, required_alignment=core.__dlpack_auto_import_required_alignment__
)
elif isinstance(value, Exception):
return core._convert_to_ffi_error(value)
else:
raise TypeError(f"don't know how to convert type {type(value)} to object")
# in this case, it is an opaque python object
return core._convert_to_opaque_object(value)


core._set_func_convert_to_object(convert)
7 changes: 7 additions & 0 deletions ffi/python/tvm_ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIArray = 71
kTVMFFIMap = 72
kTVMFFIModule = 73
kTVMFFIOpaquePyObject = 74


ctypedef void* TVMFFIObjectHandle
Expand Down Expand Up @@ -111,6 +112,9 @@ cdef extern from "tvm/ffi/c_api.h":
const char* data
size_t size

ctypedef struct TVMFFIOpaqueObjectCell:
void* handle

ctypedef struct TVMFFIShapeCell:
const int64_t* data
size_t size
Expand Down Expand Up @@ -172,6 +176,8 @@ cdef extern from "tvm/ffi/c_api.h":
const TVMFFITypeMetadata* metadata

int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil
int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index,
void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil
int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args,
TVMFFIAny* result) nogil
Expand Down Expand Up @@ -203,6 +209,7 @@ cdef extern from "tvm/ffi/c_api.h":
TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil
TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil
TVMFFIOpaqueObjectCell* TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) nogil
TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil
DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil
DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil
Expand Down
30 changes: 25 additions & 5 deletions ffi/python/tvm_ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ cdef inline object make_ret(TVMFFIAny result):
if type_index == kTVMFFINDArray:
# specially handle NDArray as it needs a special dltensor field
return make_ndarray_from_any(result)
elif type_index == kTVMFFIOpaquePyObject:
return make_ret_opaque_object(result)
elif type_index >= kTVMFFIStaticObjectBegin:
return make_ret_object(result)
elif type_index == kTVMFFINone:
Expand Down Expand Up @@ -182,7 +184,10 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
out[i].v_ptr = (<Object>arg).chandle
temp_args.append(arg)
else:
raise TypeError("Unsupported argument type: %s" % type(arg))
arg = _convert_to_opaque_object(arg)
out[i].type_index = kTVMFFIOpaquePyObject
out[i].v_ptr = (<Object>arg).chandle
temp_args.append(arg)


cdef inline int FuncCall3(void* chandle,
Expand Down Expand Up @@ -431,9 +436,9 @@ def _get_global_func(name, allow_missing):


# handle callbacks
cdef void tvm_ffi_callback_deleter(void* fhandle) noexcept with gil:
local_pyfunc = <object>(fhandle)
Py_DECREF(local_pyfunc)
cdef void tvm_ffi_pyobject_deleter(void* fhandle) noexcept with gil:
local_pyobject = <object>(fhandle)
Py_DECREF(local_pyobject)


cdef int tvm_ffi_callback(void* context,
Expand Down Expand Up @@ -468,12 +473,27 @@ def _convert_to_ffi_func(object pyfunc):
CHECK_CALL(TVMFFIFunctionCreate(
<void*>(pyfunc),
tvm_ffi_callback,
tvm_ffi_callback_deleter,
tvm_ffi_pyobject_deleter,
&chandle))
ret = Function.__new__(Function)
(<Object>ret).chandle = chandle
return ret


def _convert_to_opaque_object(object pyobject):
"""Convert a python object to TVM FFI opaque object"""
cdef TVMFFIObjectHandle chandle
Py_INCREF(pyobject)
CHECK_CALL(TVMFFIObjectCreateOpaque(
<void*>(pyobject),
kTVMFFIOpaquePyObject,
tvm_ffi_pyobject_deleter,
&chandle))
ret = OpaquePyObject.__new__(OpaquePyObject)
(<Object>ret).chandle = chandle
return ret


_STR_CONSTRUCTOR = _get_global_func("ffi.String", False)
_BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False)
_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True)
Expand Down
17 changes: 17 additions & 0 deletions ffi/python/tvm_ffi/cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,17 @@ cdef class Object:
(<Object>other).chandle = NULL


cdef class OpaquePyObject(Object):
"""Opaque PyObject container"""
def pyobject(self):
"""Get the underlying python object"""
cdef object obj
cdef PyObject* py_handle
py_handle = <PyObject*>(TVMFFIOpaqueObjectGetCellPtr(self.chandle).handle)
obj = <object>py_handle
return obj


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []
Expand Down Expand Up @@ -252,6 +263,12 @@ cdef inline str _type_index_to_key(int32_t tindex):
return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size))


cdef inline object make_ret_opaque_object(TVMFFIAny result):
obj = OpaquePyObject.__new__(OpaquePyObject)
(<Object>obj).chandle = result.v_obj
return obj.pyobject()


cdef inline object make_ret_object(TVMFFIAny result):
global OBJECT_TYPE
cdef int32_t tindex
Expand Down
40 changes: 40 additions & 0 deletions ffi/src/ffi/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>

Expand Down Expand Up @@ -385,6 +386,29 @@ class TypeTable {
Map<String, int64_t> type_attr_name_to_column_index_;
};

/**
* \brief Opaque implementation
*/
class OpaqueObjectImpl : public Object, public TVMFFIOpaqueObjectCell {
public:
OpaqueObjectImpl(void* handle, void (*deleter)(void* handle)) : deleter_(deleter) {
this->handle = handle;
}

void SetTypeIndex(int32_t type_index) {
details::ObjectUnsafe::GetHeader(this)->type_index = type_index;
}

~OpaqueObjectImpl() {
if (deleter_ != nullptr) {
deleter_(handle);
}
}

private:
void (*deleter_)(void* handle);
};

} // namespace ffi
} // namespace tvm

Expand All @@ -400,6 +424,22 @@ int TVMFFIObjectIncRef(TVMFFIObjectHandle handle) {
TVM_FFI_SAFE_CALL_END();
}

int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, void (*deleter)(void* handle),
TVMFFIObjectHandle* out) {
TVM_FFI_SAFE_CALL_BEGIN();
if (type_index != kTVMFFIOpaquePyObject) {
TVM_FFI_THROW(RuntimeError) << "Only kTVMFFIOpaquePyObject is supported for now";
}
// create initial opaque object
tvm::ffi::ObjectPtr<tvm::ffi::OpaqueObjectImpl> p =
tvm::ffi::make_object<tvm::ffi::OpaqueObjectImpl>(handle, deleter);
// need to set the type index after creation, because the set to RuntimeTypeIndex()
// happens after the constructor is called
p->SetTypeIndex(type_index);
*out = tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(p));
TVM_FFI_SAFE_CALL_END();
}

int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) {
TVM_FFI_SAFE_CALL_BEGIN();
out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key);
Expand Down
25 changes: 25 additions & 0 deletions ffi/tests/cpp/test_object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,29 @@ TEST(Object, WeakObjectPtrAssignment) {
EXPECT_EQ(lock3->value, 777);
}

TEST(Object, OpaqueObject) {
thread_local int deleter_trigger_counter = 0;
struct DummyOpaqueObject {
int value;
DummyOpaqueObject(int value) : value(value) {}

static void Deleter(void* handle) {
deleter_trigger_counter++;
delete static_cast<DummyOpaqueObject*>(handle);
}
};
TVMFFIObjectHandle handle = nullptr;
TVM_FFI_CHECK_SAFE_CALL(TVMFFIObjectCreateOpaque(new DummyOpaqueObject(10), kTVMFFIOpaquePyObject,
DummyOpaqueObject::Deleter, &handle));
ObjectPtr<Object> a =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<Object*>(handle));
EXPECT_EQ(a->type_index(), kTVMFFIOpaquePyObject);
EXPECT_EQ(static_cast<DummyOpaqueObject*>(TVMFFIOpaqueObjectGetCellPtr(a.get())->handle)->value,
10);
EXPECT_EQ(a.use_count(), 1);
EXPECT_EQ(deleter_trigger_counter, 0);
a.reset();
EXPECT_EQ(deleter_trigger_counter, 1);
}

} // namespace
22 changes: 22 additions & 0 deletions ffi/tests/python/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ def test_int_map():
assert tuple(amap.values()) == (2, 3)


def test_array_map_of_opaque_object():
class MyObject:
def __init__(self, value):
self.value = value

a = tvm_ffi.convert([MyObject("hello"), MyObject(1)])
assert isinstance(a, tvm_ffi.Array)
assert len(a) == 2
assert isinstance(a[0], MyObject)
assert a[0].value == "hello"
assert isinstance(a[1], MyObject)
assert a[1].value == 1

y = tvm_ffi.convert({"a": MyObject(1), "b": MyObject("hello")})
assert isinstance(y, tvm_ffi.Map)
assert len(y) == 2
assert isinstance(y["a"], MyObject)
assert y["a"].value == 1
assert isinstance(y["b"], MyObject)
assert y["b"].value == "hello"


def test_str_map():
data = []
for i in reversed(range(10)):
Expand Down
25 changes: 25 additions & 0 deletions ffi/tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import gc
import ctypes
import sys
import numpy as np
import tvm_ffi

Expand Down Expand Up @@ -161,3 +162,27 @@ def check1():

check0()
check1()


def test_echo_with_opaque_object():
class MyObject:
def __init__(self, value):
self.value = value

fecho = tvm_ffi.get_global_func("testing.echo")
x = MyObject("hello")
assert sys.getrefcount(x) == 2
y = fecho(x)
assert isinstance(y, MyObject)
assert y is x
assert sys.getrefcount(x) == 3

def py_callback(z):
"""python callback with opaque object"""
assert z is x
return z

fcallback = tvm_ffi.convert(py_callback)
z = fcallback(x)
assert z is x
assert sys.getrefcount(x) == 4
21 changes: 21 additions & 0 deletions ffi/tests/python/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pytest
import sys

import tvm_ffi

Expand Down Expand Up @@ -68,3 +69,23 @@ def test_derived_object():

obj0.v_i64 = 21
assert obj0.v_i64 == 21


class MyObject:
def __init__(self, value):
self.value = value


def test_opaque_object():
obj0 = MyObject("hello")
assert sys.getrefcount(obj0) == 2
obj0_converted = tvm_ffi.convert(obj0)
assert sys.getrefcount(obj0) == 3
assert isinstance(obj0_converted, tvm_ffi.core.OpaquePyObject)
obj0_cpy = obj0_converted.pyobject()
assert obj0_cpy is obj0
assert sys.getrefcount(obj0) == 4
obj0_converted = None
assert sys.getrefcount(obj0) == 3
obj0_cpy = None
assert sys.getrefcount(obj0) == 2
Loading