Skip to content

Commit 301f428

Browse files
authored
[FFI][REFACTOR] Cleanup tvm_ffi python API and types (apache#18277)
This PR cleans up the python API to make things more consistent with existing python array api and torch. Device update - device_id => index, to be consistent with torch - device_type => dlpack_device_type() returns int - added type property same as torch.device API updates: - Move the convenient method like cpu() out into tvm runtime to keep device minimal - tvm_ffi._init_api => tvm_ffi.init_ffi_api - tvm_ffi.register_func => tvm_ffi.register_global_func
1 parent 05daaf8 commit 301f428

File tree

24 files changed

+547
-453
lines changed

24 files changed

+547
-453
lines changed

docs/get_started/quick_start.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,9 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA);
144144
### Working with PyTorch
145145

146146
Atfer build, we will create library such as `build/add_one_cuda.so`, that can be loaded by
147-
with api `tvm_ffi.load_module`. Then the function will become available as property of the loaded module.
148-
The tensor arguments in the ffi functions automatically consumes torch.Tensor. The following code shows how
147+
with api {py:func}`tvm_ffi.load_module` that returns a {py:class}`tvm_ffi.Module`
148+
Then the function will become available as property of the loaded module.
149+
The tensor arguments in the ffi functions automatically consumes `torch.Tensor`. The following code shows how
149150
to use the function in torch.
150151

151152
```python

docs/guides/packaging.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ _LIB = _load_lib()
204204

205205
Effectively, it leverages the `tvm_ffi.load_module` call to load the library
206206
extension DLL shipped along with the package. The `_ffi_api.py` contains a function
207-
call to `tvm_ffi._init_api` that registers all global functions prefixed
207+
call to `tvm_ffi.init_ffi_api` that registers all global functions prefixed
208208
with `my_ffi_extension` into the module.
209209

210210
```python
@@ -214,7 +214,7 @@ from .base import _LIB
214214

215215
# Register all global functions prefixed with 'my_ffi_extension.'
216216
# This makes functions registered via TVM_FFI_STATIC_INIT_BLOCK available
217-
tvm_ffi._init_api("my_ffi_extension", __name__)
217+
tvm_ffi.init_ffi_api("my_ffi_extension", __name__)
218218
```
219219

220220
Then we can redirect the calls to the related functions.

docs/guides/python_guide.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ y = np.empty_like(x)
4747
mod.add_one_cpu(x, y)
4848
```
4949

50-
In this case, `tvm_ffi.load_module` will return a `tvm_ffi.Module` class that contains
50+
In this case, {py:func}`tvm_ffi.load_module` will return a {py:class}`tvm_ffi.Module` class that contains
5151
the exported functions. You can access the functions by their names.
5252

5353
## Tensor
@@ -67,12 +67,12 @@ np_result = np.from_dlpack(tvm_array)
6767

6868
In most cases, however, you do not have to explicitly create Tensors.
6969
The Python interface can take in `torch.Tensor` and `numpy.ndarray` objects
70-
and automatically convert them to `tvm_ffi.Tensor`.
70+
and automatically convert them to {py:class}`tvm_ffi.Tensor`.
7171

7272
## Functions and Callbacks
7373

74-
`tvm_ffi.Function` provides the Python interface for `ffi::Function` in the C++.
75-
You can retrieve globally registered functions via `tvm_ffi.get_global_func()`.
74+
{py:class}`tvm_ffi.Function` provides the Python interface for `ffi::Function` in the C++.
75+
You can retrieve globally registered functions via {py:func}`tvm_ffi.get_global_func`.
7676

7777
```python
7878
import tvm_ffi
@@ -84,8 +84,8 @@ assert fecho(1) == 1
8484
```
8585

8686
You can pass a Python function as an argument to another FFI function as callbacks.
87-
Under the hood, `tvm_ffi.convert` is called to convert the Python function into a
88-
`tvm_ffi.Function`.
87+
Under the hood, {py:func}`tvm_ffi.convert` is called to convert the Python function into a
88+
{py:class}`tvm_ffi.Function`.
8989

9090
```python
9191
import tvm_ffi
@@ -103,7 +103,7 @@ You can also register a Python callback as a global function.
103103
```python
104104
import tvm_ffi
105105

106-
@tvm_ffi.register_func("example.add_one")
106+
@tvm_ffi.register_global_func("example.add_one")
107107
def add_one(a):
108108
return a + 1
109109

@@ -112,7 +112,7 @@ assert tvm_ffi.get_global_func("example.add_one")(1) == 2
112112

113113
## Container Types
114114

115-
When an FFI function takes arguments from lists/tuples, they will be converted into `tvm_ffi.Array`.
115+
When an FFI function takes arguments from lists/tuples, they will be converted into {py:class}`tvm_ffi.Array`.
116116

117117
```python
118118
import tvm_ffi
@@ -124,7 +124,7 @@ assert len(arr) == 4
124124
assert arr[0] == 1
125125
```
126126

127-
Dictionaries will be converted to `tvm_ffi.Map`
127+
Dictionaries will be converted to {py:class}`tvm_ffi.Map`
128128

129129
```python
130130
import tvm_ffi
@@ -167,7 +167,7 @@ File "src/ffi/extra/testing.cc", line 60, in void tvm::ffi::TestRaiseError(tvm::
167167
throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0));
168168
```
169169

170-
We register common error kinds. You can also register extra error dispatch via the `tvm_ffi.register_error` function.
170+
We register common error kinds. You can also register extra error dispatch via the {py:func}`tvm_ffi.register_error` function.
171171

172172
## Advanced: Register Your Own Object
173173

@@ -239,5 +239,5 @@ assert test_int_pair.b == 2
239239
Under the hood, we leverage the information registered through the reflection registry to
240240
generate efficient field accessors and methods for each class.
241241

242-
Importantly, when you have multiple inheritance, you need to call `tvm_ffi.register_object`
242+
Importantly, when you have multiple inheritance, you need to call {py:func}`tvm_ffi.register_object`
243243
on both the base class and the child class.

docs/index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,9 @@ Apache TVM FFI Documentation
3939
:caption: Concepts
4040

4141
concepts/abi_overview.md
42+
43+
.. toctree::
44+
:maxdepth: 1
45+
:caption: Reference
46+
47+
reference/python/index.rst

docs/reference/python/index.rst

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
.. Licensed to the Apache Software Foundation (ASF) under one
2+
or more contributor license agreements. See the NOTICE file
3+
distributed with this work for additional information
4+
regarding copyright ownership. The ASF licenses this file
5+
to you under the Apache License, Version 2.0 (the
6+
"License"); you may not use this file except in compliance
7+
with the License. You may obtain a copy of the License at
8+
9+
.. http://www.apache.org/licenses/LICENSE-2.0
10+
11+
.. Unless required by applicable law or agreed to in writing,
12+
software distributed under the License is distributed on an
13+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
KIND, either express or implied. See the License for the
15+
specific language governing permissions and limitations
16+
under the License.
17+
18+
Python API
19+
==========
20+
21+
.. automodule:: tvm_ffi
22+
:no-members:
23+
24+
.. currentmodule:: tvm_ffi
25+
26+
Object
27+
------
28+
.. autosummary::
29+
:toctree: generated/
30+
31+
Object
32+
register_object
33+
34+
35+
Function and Module
36+
-------------------
37+
.. autosummary::
38+
:toctree: generated/
39+
40+
41+
Function
42+
Module
43+
register_global_func
44+
get_global_func
45+
system_lib
46+
load_module
47+
init_ffi_api
48+
register_error
49+
convert
50+
51+
52+
Tensor
53+
------
54+
.. autosummary::
55+
:toctree: generated/
56+
57+
Shape
58+
Tensor
59+
Device
60+
from_dlpack
61+
62+
63+
Containers
64+
----------
65+
.. autosummary::
66+
:toctree: generated/
67+
68+
Array
69+
Map

examples/packaging/python/my_ffi_extension/_ffi_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@
2121

2222
# this is a short cut to register all the global functions
2323
# prefixed by `my_ffi_extension.` to this module
24-
tvm_ffi._init_api("my_ffi_extension", __name__)
24+
tvm_ffi.init_ffi_api("my_ffi_extension", __name__)

python/tvm_ffi/__init__.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,50 +20,43 @@
2020
from . import libinfo
2121

2222
# package init part
23-
from .registry import register_object, register_func, get_global_func, _init_api
24-
from .dtype import dtype, DataTypeCode
25-
from .core import String, Bytes
26-
from .core import Object, ObjectGeneric, Function
27-
from .convert import convert
23+
from .registry import (
24+
register_object,
25+
register_global_func,
26+
get_global_func,
27+
remove_global_func,
28+
init_ffi_api,
29+
)
30+
from ._dtype import dtype
31+
from .core import Object, ObjectConvertible, Function
32+
from ._convert import convert
2833
from .error import register_error
29-
from .tensor import Device, device
30-
from .tensor import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu
31-
from .tensor import from_dlpack, Tensor, Shape
34+
from ._tensor import Device, device, DLDeviceType
35+
from ._tensor import from_dlpack, Tensor, Shape
3236
from .container import Array, Map
33-
from .module import Module, ModulePropertyMask, system_lib, load_module
37+
from .module import Module, system_lib, load_module
3438
from . import serialization
3539
from . import access_path
3640
from . import testing
3741

3842

3943
__all__ = [
4044
"dtype",
41-
"DataTypeCode",
4245
"Device",
4346
"Object",
4447
"register_object",
45-
"register_func",
48+
"register_global_func",
4649
"get_global_func",
47-
"_init_api",
50+
"remove_global_func",
51+
"init_ffi_api",
4852
"Object",
49-
"ObjectGeneric",
53+
"ObjectConvertible",
5054
"Function",
5155
"convert",
52-
"String",
53-
"Bytes",
5456
"register_error",
5557
"Device",
5658
"device",
57-
"cpu",
58-
"cuda",
59-
"rocm",
60-
"opencl",
61-
"metal",
62-
"vpi",
63-
"vulkan",
64-
"ext_dev",
65-
"hexagon",
66-
"webgpu",
59+
"DLDeviceType",
6760
"from_dlpack",
6861
"Tensor",
6962
"Shape",
@@ -73,7 +66,6 @@
7366
"access_path",
7467
"serialization",
7568
"Module",
76-
"ModulePropertyMask",
7769
"system_lib",
7870
"load_module",
7971
]

python/tvm_ffi/convert.py renamed to python/tvm_ffi/_convert.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ def convert(value: Any) -> Any:
3333
-------
3434
ffi_obj : Any
3535
The converted TVM FFI object.
36+
37+
Note
38+
----
39+
Function arguments to ffi function calls are
40+
automatically converted. So this function is mainly
41+
only used in internal or testing scenarios.
3642
"""
3743
if isinstance(value, core.Object):
3844
return value
@@ -48,7 +54,7 @@ def convert(value: Any) -> Any:
4854
return core.String(value)
4955
elif isinstance(value, (bytes, bytearray)):
5056
return core.Bytes(value)
51-
elif isinstance(value, core.ObjectGeneric):
57+
elif isinstance(value, core.ObjectConvertible):
5258
return value.asobject()
5359
elif callable(value):
5460
return core._convert_to_ffi_func(value)

python/tvm_ffi/dtype.py renamed to python/tvm_ffi/_dtype.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
class DataTypeCode(IntEnum):
25-
"""DataType code in DLTensor."""
25+
"""DLDataTypeCode code in DLTensor."""
2626

2727
INT = 0
2828
UINT = 1
@@ -57,7 +57,7 @@ class dtype(str):
5757

5858
__slots__ = ["__tvm_ffi_dtype__"]
5959

60-
NUMPY_DTYPE_TO_STR = {}
60+
_NUMPY_DTYPE_TO_STR = {}
6161

6262
def __new__(cls, content):
6363
content = str(content)
@@ -111,30 +111,30 @@ def lanes(self):
111111
# although almost in all cases we want numpy
112112
import numpy as np
113113

114-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool"
115-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8"
116-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16"
117-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32"
118-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64"
119-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8"
120-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16"
121-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32"
122-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64"
123-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16"
124-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32"
125-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64"
114+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool"
115+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8"
116+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16"
117+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32"
118+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64"
119+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8"
120+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16"
121+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32"
122+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64"
123+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16"
124+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32"
125+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64"
126126
if hasattr(np, "float_"):
127-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64"
127+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64"
128128
except ImportError:
129129
pass
130130

131131
try:
132132
import ml_dtypes
133133

134-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
135-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
136-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
137-
dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
134+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
135+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
136+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
137+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
138138
except ImportError:
139139
pass
140140

python/tvm_ffi/_ffi_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""FFI API."""
18-
from .registry import _init_api
18+
from . import registry
1919

20-
21-
_init_api("ffi", __name__)
20+
registry.init_ffi_api("ffi", __name__)

0 commit comments

Comments
 (0)