Skip to content

Commit acffe5f

Browse files
committed
Add type hint. (apache#20)
1 parent 4d220ea commit acffe5f

File tree

3 files changed

+48
-26
lines changed

3 files changed

+48
-26
lines changed

python/tvm/relax/exec_builder.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
# under the License.
1717

1818
from enum import IntEnum
19+
from typing import Optional, Union, List
1920
import tvm
21+
from tvm._ffi._ctypes.packed_func import TVMRetValueHandle
2022
from tvm.runtime import Object
2123
from tvm._ffi.base import _LIB, check_call
24+
from . vm import Executable
2225
from . import _ffi_api
2326

2427
class SpecialReg(IntEnum):
@@ -40,40 +43,46 @@ def __exit__(self, ptype, value, trace):
4043
@tvm._ffi.register_object("relax.ExecBuilder")
4144
class ExecBuilder(Object):
4245
"""A builder to emit instructions and build executable for the virtual machine."""
43-
def __init__(self):
46+
47+
def __init__(self) -> None:
4448
self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate)
4549

46-
def r(self, idx):
50+
def r(self, idx: int) -> int:
4751
"""set instruction's argument as a register."""
4852
return _ffi_api.ExecBuilderR(self, idx)
4953

50-
def imm(self, value):
54+
def imm(self, value: int) -> int:
5155
"""set instruction's argument as an immediate."""
5256
return _ffi_api.ExecBuilderImm(self, value)
5357

54-
def c(self, idx):
58+
def c(self, idx: int) -> int:
5559
"""set instruction's argument as a constant."""
5660
return _ffi_api.ExecBuilderC(self, idx)
5761

58-
def void_arg(self):
62+
def void_arg(self) -> int:
5963
return self.r(SpecialReg.VOID_ARG)
6064

61-
def vm_state(self):
65+
def vm_state(self) -> int:
6266
return self.r(SpecialReg.VM_STATE)
6367

64-
def function(self, func_name, num_inputs=0):
68+
def function(self, func_name: str, num_inputs: Optional[int] = 0) -> VMFuncScope:
6569
"""annotate a VM function."""
6670
_ffi_api.ExecBuilderFunction(self, func_name, num_inputs)
6771
return VMFuncScope()
6872

69-
def _check_scope(self):
73+
def _check_scope(self) -> None:
7074
if len(VMFuncScope.stack) == 0:
7175
raise ValueError("emit should happen in a function scope")
7276

73-
def emit_constant(self, const):
77+
def emit_constant(self, const: TVMRetValueHandle) -> int:
7478
return _ffi_api.ExecBuilderEmitConstant(self, const)
7579

76-
def emit_call(self, name, args=[], dst=None):
80+
def emit_call(
81+
self,
82+
name: str,
83+
args: Optional[List[Union[tvm.nd.NDArray, tvm.DataType]]] = [],
84+
dst: int = None,
85+
) -> None:
7786
"""emit a call instruction which calls a packed function."""
7887
self._check_scope()
7988
if dst is None:
@@ -87,12 +96,11 @@ def emit_call(self, name, args=[], dst=None):
8796
args_.append(arg)
8897
_ffi_api.ExecBuilderEmitCall(self, name, args_, dst)
8998

90-
def emit_ret(self, result):
99+
def emit_ret(self, result: int) -> None:
91100
"""emit a return instruction"""
92101
self._check_scope()
93102
_ffi_api.ExecBuilderEmitRet(self, result)
94103

95-
def get(self):
104+
def get(self) -> Executable:
96105
"""return the executable"""
97106
return _ffi_api.ExecBuilderGet(self)
98-

python/tvm/relax/vm.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from typing import List, Optional, Union, Dict
1819
import tvm
19-
from tvm.runtime import Object
20+
from tvm.runtime import Object, Device, Module, PackedFunc
2021
from tvm._ffi.base import _LIB, check_call
2122
from . import _ffi_api
2223
from ..rpc.base import RPC_SESS_MASK
@@ -25,35 +26,45 @@
2526
@tvm._ffi.register_object("relax.Executable")
2627
class Executable(Object):
2728
"""The executable object emitted by the VM compiler or the ExecBuilder."""
29+
2830
def __init__(self):
2931
self.__init_handle_by_constructor__(_ffi_api.Executable)
3032

31-
def stats(self):
33+
def stats(self) -> str:
3234
"""print the detailed statistics of the executable."""
3335
return _ffi_api.ExecutableStats(self)
3436

35-
def save_to_file(self, file_name):
37+
def save_to_file(self, file_name: str) -> None:
3638
"""serialize and write the executable to a file."""
37-
return _ffi_api.ExecutableSaveToFile(self, file_name)
39+
_ffi_api.ExecutableSaveToFile(self, file_name)
3840

39-
def astext(self):
41+
def astext(self) -> str:
4042
"""print the instructions as text format."""
4143
return _ffi_api.ExecutableAsText(self)
42-
43-
def aspython(self):
44+
45+
def aspython(self) -> str:
4446
"""print the instructions as python program."""
4547
return _ffi_api.ExecutableAsPython(self)
4648

47-
def load_exec_from_file(file_name):
49+
50+
def load_exec_from_file(file_name: str) -> Executable:
4851
return _ffi_api.ExecutableLoadFromFile(file_name)
4952

53+
5054
class VirtualMachine(object):
5155
"""Relax VM runtime."""
5256

5357
NAIVE_ALLOCATOR = 1
5458
POOLED_ALLOCATOR = 2
55-
56-
def __init__(self, exec, device, memory_cfg=None, mod=None):
59+
60+
def __init__(
61+
self,
62+
exec: Executable,
63+
device: Union[Device, List[Device]],
64+
memory_cfg: Optional[Union[str, Dict[Device, str]]] = None,
65+
mod: Optional[Module] = None,
66+
) -> None:
67+
5768
"""
5869
Construct a VirtualMachine wrapper object.
5970
@@ -73,6 +84,9 @@ def __init__(self, exec, device, memory_cfg=None, mod=None):
7384
type specified in the dict, or pooled allocator if not specified in the
7485
dict.
7586
87+
mod : tvm.runtime.Module, optional
88+
Optional runtime module to load to the VM.
89+
7690
Returns
7791
-------
7892
vm: VirtualMachine
@@ -81,7 +95,7 @@ def __init__(self, exec, device, memory_cfg=None, mod=None):
8195
self.module = _ffi_api.VirtualMachine(exec, mod)
8296
self._setup_device(device, memory_cfg)
8397

84-
def _setup_device(self, dev, memory_cfg):
98+
def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None:
8599
"""init devices and allocators."""
86100
devs = dev
87101
if not isinstance(dev, (list, tuple)):
@@ -117,5 +131,5 @@ def _setup_device(self, dev, memory_cfg):
117131
init_args.append(alloc_type)
118132
_ffi_api.VirtualMachineInit(self.module, *init_args)
119133

120-
def __getitem__(self, key):
134+
def __getitem__(self, key: str) -> PackedFunc:
121135
return self.module[key]

src/relax/vm/executable.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ TVM_REGISTER_GLOBAL("relax.ExecutableAsPython").set_body_typed([](Executable exe
462462

463463
TVM_REGISTER_GLOBAL("relax.ExecutableSaveToFile")
464464
.set_body_typed([](Executable exec, std::string file_name) {
465-
return exec->SaveToFile(file_name);
465+
exec->SaveToFile(file_name);
466466
});
467467

468468
TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed([](std::string file_name) {

0 commit comments

Comments
 (0)