Skip to content

Commit 7f9d33b

Browse files
committed
[FFI][REFACTOR] Cleanup tvm_ffi python API and types
This PR cleans up the python API to make things more consistent with existing python array api and torch. Device update - device_id => index, to be consistent with torch - device_type => dlpack_device_type() returns int - added type property same as torch.device API updates: - Move the convenient method like cpu() out into tvm runtime to keep device minimal - tvm_ffi._init_api => tvm_ffi.init_ffi_api - tvm_ffi.register_func => tvm_ffi.register_global_func
1 parent 3c36ce2 commit 7f9d33b

File tree

193 files changed

+994
-760
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

193 files changed

+994
-760
lines changed

apps/ios_rpc/tests/ios_rpc_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040

4141
# override metal compiler to compile to iphone
42-
@tvm.register_func("tvm_callback_metal_compile")
42+
@tvm.register_global_func("tvm_callback_metal_compile")
4343
def compile_metal(src, target):
4444
return xcode.compile_metal(src, sdk=sdk)
4545

docs/arch/device_target_interactions.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ then be registered with the following steps.
169169
enum value to a string representation. This string representation
170170
should match the name given to ``GlobalDef().def``.
171171

172-
#. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of
172+
#. Add entries to the ``_DEVICE_TYPE_TO_NAME`` and ``_DEVICE_NAME_TO_TYPE`` dictionaries of
173173
:py:class:`tvm.runtime.Device` for the new enum value.
174174

175175

docs/get_started/tutorials/quick_start.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ def forward(self, x):
164164
# .. code-block:: Python
165165
#
166166
# # Convert PyTorch tensor to TVM Tensor
167-
# x_tvm = tvm.runtime.from_dlpack(x_torch.to_dlpack())
167+
# x_tvm = tvm.runtime.from_dlpack(x_torch)
168168
# # Convert TVM Tensor to PyTorch tensor
169-
# x_torch = torch.from_dlpack(x_tvm.to_dlpack())
169+
# x_torch = torch.from_dlpack(x_tvm)
170170
#
171171
# - TVM runtime works in non-python environments, so it works on settings such as mobile
172172
#

ffi/docs/get_started/quick_start.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,9 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA);
144144
### Working with PyTorch
145145

146146
Atfer build, we will create library such as `build/add_one_cuda.so`, that can be loaded by
147-
with api `tvm_ffi.load_module`. Then the function will become available as property of the loaded module.
148-
The tensor arguments in the ffi functions automatically consumes torch.Tensor. The following code shows how
147+
with api {py:func}`tvm_ffi.load_module` that returns a {py:class}`tvm_ffi.Module`
148+
Then the function will become available as property of the loaded module.
149+
The tensor arguments in the ffi functions automatically consumes `torch.Tensor`. The following code shows how
149150
to use the function in torch.
150151

151152
```python

ffi/docs/guides/packaging.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ _LIB = _load_lib()
204204

205205
Effectively, it leverages the `tvm_ffi.load_module` call to load the library
206206
extension DLL shipped along with the package. The `_ffi_api.py` contains a function
207-
call to `tvm_ffi._init_api` that registers all global functions prefixed
207+
call to `tvm_ffi.init_ffi_api` that registers all global functions prefixed
208208
with `my_ffi_extension` into the module.
209209

210210
```python
@@ -214,7 +214,7 @@ from .base import _LIB
214214

215215
# Register all global functions prefixed with 'my_ffi_extension.'
216216
# This makes functions registered via TVM_FFI_STATIC_INIT_BLOCK available
217-
tvm_ffi._init_api("my_ffi_extension", __name__)
217+
tvm_ffi.init_ffi_api("my_ffi_extension", __name__)
218218
```
219219

220220
Then we can redirect the calls to the related functions.

ffi/docs/guides/python_guide.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ y = np.empty_like(x)
4747
mod.add_one_cpu(x, y)
4848
```
4949

50-
In this case, `tvm_ffi.load_module` will return a `tvm_ffi.Module` class that contains
50+
In this case, {py:func}`tvm_ffi.load_module` will return a {py:class}`tvm_ffi.Module` class that contains
5151
the exported functions. You can access the functions by their names.
5252

5353
## Tensor
@@ -67,12 +67,12 @@ np_result = np.from_dlpack(tvm_array)
6767

6868
In most cases, however, you do not have to explicitly create Tensors.
6969
The Python interface can take in `torch.Tensor` and `numpy.ndarray` objects
70-
and automatically convert them to `tvm_ffi.Tensor`.
70+
and automatically convert them to {py:class}`tvm_ffi.Tensor`.
7171

7272
## Functions and Callbacks
7373

74-
`tvm_ffi.Function` provides the Python interface for `ffi::Function` in the C++.
75-
You can retrieve globally registered functions via `tvm_ffi.get_global_func()`.
74+
{py:class}`tvm_ffi.Function` provides the Python interface for `ffi::Function` in the C++.
75+
You can retrieve globally registered functions via {py:func}`tvm_ffi.get_global_func`.
7676

7777
```python
7878
import tvm_ffi
@@ -84,8 +84,8 @@ assert fecho(1) == 1
8484
```
8585

8686
You can pass a Python function as an argument to another FFI function as callbacks.
87-
Under the hood, `tvm_ffi.convert` is called to convert the Python function into a
88-
`tvm_ffi.Function`.
87+
Under the hood, {py:func}`tvm_ffi.convert` is called to convert the Python function into a
88+
{py:class}`tvm_ffi.Function`.
8989

9090
```python
9191
import tvm_ffi
@@ -103,7 +103,7 @@ You can also register a Python callback as a global function.
103103
```python
104104
import tvm_ffi
105105

106-
@tvm_ffi.register_func("example.add_one")
106+
@tvm_ffi.register_global_func("example.add_one")
107107
def add_one(a):
108108
return a + 1
109109

@@ -112,7 +112,7 @@ assert tvm_ffi.get_global_func("example.add_one")(1) == 2
112112

113113
## Container Types
114114

115-
When an FFI function takes arguments from lists/tuples, they will be converted into `tvm_ffi.Array`.
115+
When an FFI function takes arguments from lists/tuples, they will be converted into {py:class}`tvm_ffi.Array`.
116116

117117
```python
118118
import tvm_ffi
@@ -124,7 +124,7 @@ assert len(arr) == 4
124124
assert arr[0] == 1
125125
```
126126

127-
Dictionaries will be converted to `tvm_ffi.Map`
127+
Dictionaries will be converted to {py:class}`tvm_ffi.Map`
128128

129129
```python
130130
import tvm_ffi
@@ -167,7 +167,7 @@ File "src/ffi/extra/testing.cc", line 60, in void tvm::ffi::TestRaiseError(tvm::
167167
throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0));
168168
```
169169

170-
We register common error kinds. You can also register extra error dispatch via the `tvm_ffi.register_error` function.
170+
We register common error kinds. You can also register extra error dispatch via the {py:func}`tvm_ffi.register_error` function.
171171

172172
## Advanced: Register Your Own Object
173173

@@ -239,5 +239,5 @@ assert test_int_pair.b == 2
239239
Under the hood, we leverage the information registered through the reflection registry to
240240
generate efficient field accessors and methods for each class.
241241

242-
Importantly, when you have multiple inheritance, you need to call `tvm_ffi.register_object`
242+
Importantly, when you have multiple inheritance, you need to call {py:func}`tvm_ffi.register_object`
243243
on both the base class and the child class.

ffi/docs/index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,9 @@ Apache TVM FFI Documentation
3939
:caption: Concepts
4040

4141
concepts/abi_overview.md
42+
43+
.. toctree::
44+
:maxdepth: 1
45+
:caption: Reference
46+
47+
reference/python/index.rst
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
.. Licensed to the Apache Software Foundation (ASF) under one
2+
or more contributor license agreements. See the NOTICE file
3+
distributed with this work for additional information
4+
regarding copyright ownership. The ASF licenses this file
5+
to you under the Apache License, Version 2.0 (the
6+
"License"); you may not use this file except in compliance
7+
with the License. You may obtain a copy of the License at
8+
9+
.. http://www.apache.org/licenses/LICENSE-2.0
10+
11+
.. Unless required by applicable law or agreed to in writing,
12+
software distributed under the License is distributed on an
13+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
KIND, either express or implied. See the License for the
15+
specific language governing permissions and limitations
16+
under the License.
17+
18+
Python API
19+
==========
20+
21+
.. automodule:: tvm_ffi
22+
:no-members:
23+
24+
.. currentmodule:: tvm_ffi
25+
26+
Object
27+
------
28+
.. autosummary::
29+
:toctree: generated/
30+
31+
Object
32+
register_object
33+
34+
35+
Function and Module
36+
-------------------
37+
.. autosummary::
38+
:toctree: generated/
39+
40+
41+
Function
42+
Module
43+
register_global_func
44+
get_global_func
45+
system_lib
46+
load_module
47+
init_ffi_api
48+
register_error
49+
convert
50+
51+
52+
Tensor
53+
------
54+
.. autosummary::
55+
:toctree: generated/
56+
57+
Shape
58+
Tensor
59+
Device
60+
from_dlpack
61+
62+
63+
Containers
64+
----------
65+
.. autosummary::
66+
:toctree: generated/
67+
68+
Array
69+
Map

ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@
2121

2222
# this is a short cut to register all the global functions
2323
# prefixed by `my_ffi_extension.` to this module
24-
tvm_ffi._init_api("my_ffi_extension", __name__)
24+
tvm_ffi.init_ffi_api("my_ffi_extension", __name__)

ffi/python/tvm_ffi/__init__.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,50 +20,43 @@
2020
from . import libinfo
2121

2222
# package init part
23-
from .registry import register_object, register_func, get_global_func, _init_api
24-
from .dtype import dtype, DataTypeCode
25-
from .core import String, Bytes
26-
from .core import Object, ObjectGeneric, Function
23+
from .registry import (
24+
register_object,
25+
register_global_func,
26+
get_global_func,
27+
remove_global_func,
28+
init_ffi_api,
29+
)
30+
from .dtype import dtype
31+
from .core import Object, ObjectConvertible, Function
2732
from .convert import convert
2833
from .error import register_error
29-
from .tensor import Device, device
30-
from .tensor import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu
34+
from .tensor import Device, device, DLDeviceType
3135
from .tensor import from_dlpack, Tensor, Shape
3236
from .container import Array, Map
33-
from .module import Module, ModulePropertyMask, system_lib, load_module
37+
from .module import Module, system_lib, load_module
3438
from . import serialization
3539
from . import access_path
3640
from . import testing
3741

3842

3943
__all__ = [
4044
"dtype",
41-
"DataTypeCode",
4245
"Device",
4346
"Object",
4447
"register_object",
45-
"register_func",
48+
"register_global_func",
4649
"get_global_func",
47-
"_init_api",
50+
"remove_global_func",
51+
"init_ffi_api",
4852
"Object",
49-
"ObjectGeneric",
53+
"ObjectConvertible",
5054
"Function",
5155
"convert",
52-
"String",
53-
"Bytes",
5456
"register_error",
5557
"Device",
5658
"device",
57-
"cpu",
58-
"cuda",
59-
"rocm",
60-
"opencl",
61-
"metal",
62-
"vpi",
63-
"vulkan",
64-
"ext_dev",
65-
"hexagon",
66-
"webgpu",
59+
"DLDeviceType",
6760
"from_dlpack",
6861
"Tensor",
6962
"Shape",
@@ -73,7 +66,6 @@
7366
"access_path",
7467
"serialization",
7568
"Module",
76-
"ModulePropertyMask",
7769
"system_lib",
7870
"load_module",
7971
]

0 commit comments

Comments
 (0)