Skip to content

Commit b139736

Browse files
[FIX][RUNTIME] Convert container with function value type (#14024)
Prior to this PR, though the `convert` function is capable of converting a single Python function/lambda to TVM func, it is not able to convert a container whose values inside are functions to TVM object. This PR adds function conversion to `convert_to_object` and redirects `convert` to `convert_to_object`, so that now the conversion is always recursive, and therefore will work well on function container value type. Co-authored-by: Chaofan Lin <[email protected]>
1 parent d12a636 commit b139736

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

python/tvm/runtime/object_generic.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def asobject(self):
3535
raise NotImplementedError()
3636

3737

38-
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)
38+
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PackedFuncBase, PyNativeObject)
3939

4040

4141
def convert_to_object(value, span=None):
@@ -79,6 +79,8 @@ def convert_to_object(value, span=None):
7979
return _ffi_api.Map(*vlist)
8080
if isinstance(value, ObjectGeneric):
8181
return value.asobject()
82+
if callable(value):
83+
return convert_to_tvm_func(value)
8284
if value is None:
8385
return None
8486

@@ -99,13 +101,12 @@ def convert(value, span=None):
99101
-------
100102
tvm_val : Object or Function
101103
Converted value in TVM
102-
"""
103-
if isinstance(value, (PackedFuncBase, ObjectBase)):
104-
return value
105-
106-
if callable(value):
107-
return convert_to_tvm_func(value)
108104
105+
Note
106+
----
107+
This function is redirected to `convert_to_object` as it is widely used in
108+
the codebase. We can choose one to keep and discard the other one later.
109+
"""
109110
return convert_to_object(value, span=span)
110111

111112

tests/python/all-platform-minimal-test/test_runtime_packed_func.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,18 @@ def check(arr):
153153
assert tvm.testing.object_use_count(x) == 1
154154

155155

156+
def test_dict_function_value_type():
157+
from tvm import tir # pylint: disable=import-outside-toplevel
158+
159+
te_func_dict = {"add": lambda a, b: a + b}
160+
161+
converted_dict = tvm.runtime.convert(te_func_dict)
162+
f = converted_dict["add"]
163+
a = tir.Var("a", "float32")
164+
b = tir.Var("b", "float32")
165+
tvm.ir.assert_structural_equal(f(a, b), tir.Add(a, b))
166+
167+
156168
if __name__ == "__main__":
157169
test_ndarray_args()
158170
test_numpy_scalar()
@@ -164,3 +176,4 @@ def check(arr):
164176
test_return_func()
165177
test_byte_array()
166178
test_device()
179+
test_dict_function_value_type()

0 commit comments

Comments
 (0)