Skip to content

Commit e7c04f5

Browse files
author
Siyuan Feng
authored
[Refactor] Introduce base Executable class and tvm.compile interface (#17710)
This refactor introduces a base Executable class and a `tvm.compile` interface that can be used to compile both TIR and Relax programs. `tvm.compile` will return an Executable object that can be used to call either TIR or Relax functions.
1 parent ec548eb commit e7c04f5

File tree

29 files changed

+764
-237
lines changed

29 files changed

+764
-237
lines changed

docs/how_to/tutorials/optimize_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
426426

427427

428428
with target:
429-
ex = relax.build(mod, target, pipeline=relax.get_pipeline("opt_llm"))
429+
ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline("opt_llm"))
430430
vm = relax.VirtualMachine(ex, dev)
431431

432432

include/tvm/relax/exec_builder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ class ExecBuilderNode : public Object {
125125
/*!
126126
* \brief Raw access to underlying executable build in progress.
127127
*/
128-
vm::Executable* exec() const;
128+
vm::VMExecutable* exec() const;
129129
/*!
130130
* \brief Finalize the build, run formalize and get the final result.
131131
* \note This function should not be called during construction.
132132
*/
133-
ObjectPtr<vm::Executable> Get();
133+
ObjectPtr<vm::VMExecutable> Get();
134134
/*!
135135
* \brief Create an ExecBuilder.
136136
* \return The ExecBuilder.
@@ -165,7 +165,7 @@ class ExecBuilderNode : public Object {
165165
void Formalize();
166166

167167
/*! \brief The mutable internal executable. */
168-
ObjectPtr<vm::Executable> exec_; // mutable
168+
ObjectPtr<vm::VMExecutable> exec_; // mutable
169169
/*! \brief internal dedup map when creating index for a new constant */
170170
std::unordered_map<ObjectRef, vm::Index, StructuralHash, StructuralEqual> const_dedup_map_;
171171
};

include/tvm/runtime/relax_vm/executable.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ struct VMFuncInfo {
8080
};
8181

8282
/*!
83-
* \brief The executable emitted by the VM compiler.
83+
* \brief The virtual machine executable emitted by the VM compiler.
8484
*
8585
* The executable contains information (e.g. data in different memory regions)
8686
* to run in a virtual machine.
8787
*/
88-
class Executable : public runtime::ModuleNode {
88+
class VMExecutable : public runtime::ModuleNode {
8989
public:
9090
/*! \brief Get the property of the runtime module .*/
9191
int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; };
@@ -120,18 +120,18 @@ class Executable : public runtime::ModuleNode {
120120
*/
121121
String AsPython() const;
122122
/*!
123-
* \brief Write the Executable to the binary stream in serialized form.
123+
* \brief Write the VMExecutable to the binary stream in serialized form.
124124
* \param stream The binary stream to save the executable to.
125125
*/
126126
void SaveToBinary(dmlc::Stream* stream) final;
127127
/*!
128-
* \brief Load Executable from the binary stream in serialized form.
128+
* \brief Load VMExecutable from the binary stream in serialized form.
129129
* \param stream The binary stream that load the executable from.
130130
* \return The loaded executable, in the form of a `runtime::Module`.
131131
*/
132132
static Module LoadFromBinary(void* stream);
133133
/*!
134-
* \brief Write the Executable to the provided path as a file containing its serialized content.
134+
* \brief Write the VMExecutable to the provided path as a file containing its serialized content.
135135
* \param file_name The name of the file to write the serialized data to.
136136
* \param format The target format of the saved file.
137137
*/
@@ -140,10 +140,10 @@ class Executable : public runtime::ModuleNode {
140140
Module VMLoadExecutable() const;
141141
/*! \brief Create a Relax virtual machine with profiler and load `this` as the executable. */
142142
Module VMProfilerLoadExecutable() const;
143-
/*! \brief Check if the Executable contains a specific function. */
143+
/*! \brief Check if the VMExecutable contains a specific function. */
144144
bool HasFunction(const String& name) const;
145145
/*!
146-
* \brief Load Executable from the file.
146+
* \brief Load VMExecutable from the file.
147147
* \param file_name The path of the file that load the executable from.
148148
* \return The loaded executable, in the form of a `runtime::Module`.
149149
*/
@@ -160,15 +160,15 @@ class Executable : public runtime::ModuleNode {
160160
/*! \brief The byte data of instruction. */
161161
std::vector<ExecWord> instr_data;
162162

163-
virtual ~Executable() {}
163+
virtual ~VMExecutable() {}
164164

165-
TVM_MODULE_VTABLE_BEGIN("relax.Executable");
166-
TVM_MODULE_VTABLE_ENTRY("stats", &Executable::Stats);
167-
TVM_MODULE_VTABLE_ENTRY("as_text", &Executable::AsText);
168-
TVM_MODULE_VTABLE_ENTRY("as_python", &Executable::AsPython);
169-
TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
170-
TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable", &Executable::VMProfilerLoadExecutable);
171-
TVM_MODULE_VTABLE_ENTRY("has_function", &Executable::HasFunction);
165+
TVM_MODULE_VTABLE_BEGIN("relax.VMExecutable");
166+
TVM_MODULE_VTABLE_ENTRY("stats", &VMExecutable::Stats);
167+
TVM_MODULE_VTABLE_ENTRY("as_text", &VMExecutable::AsText);
168+
TVM_MODULE_VTABLE_ENTRY("as_python", &VMExecutable::AsPython);
169+
TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &VMExecutable::VMLoadExecutable);
170+
TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable", &VMExecutable::VMProfilerLoadExecutable);
171+
TVM_MODULE_VTABLE_ENTRY("has_function", &VMExecutable::HasFunction);
172172
TVM_MODULE_VTABLE_END();
173173

174174
private:

include/tvm/runtime/relax_vm/vm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class VirtualMachine : public runtime::ModuleNode {
143143
* \brief Load the executable for the virtual machine.
144144
* \param exec The executable.
145145
*/
146-
virtual void LoadExecutable(ObjectPtr<Executable> exec) = 0;
146+
virtual void LoadExecutable(ObjectPtr<VMExecutable> exec) = 0;
147147
/*!
148148
* \brief Get global function in the VM.
149149
* \param func_name The name of the function.

python/tvm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from . import te
5656

5757
# tvm.driver
58-
from .driver import build
58+
from .driver import build, compile
5959

6060
# others
6161
from . import arith

python/tvm/contrib/hexagon/session.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424
from typing import Union
2525

2626
import tvm
27-
from tvm import relax
27+
import tvm.contrib.hexagon as hexagon
2828
from tvm import rpc as _rpc
29+
from tvm import runtime
2930
from tvm.contrib import utils
30-
import tvm.contrib.hexagon as hexagon
31-
from .tools import export_module, HEXAGON_SIMULATOR_NAME
31+
32+
from .tools import HEXAGON_SIMULATOR_NAME, export_module
3233

3334

3435
class Session:
@@ -202,26 +203,26 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
202203
return self._rpc.get_function("tvm.hexagon.load_module")(str(remote_file_path))
203204

204205
def get_executor_from_factory(
205-
self, module: Union[ExecutorFactoryModule, relax.Executable, str], hexagon_arch: str = "v68"
206+
self, module: Union[runtime.executable, str], hexagon_arch: str = "v68"
206207
):
207208
"""Create a local GraphModule which consumes a remote libmod.
208209
209210
Parameters
210211
----------
211212
212-
module : Union[relax.Executable]
213+
module : Union[runtime.Executable, str]
213214
214215
The module to upload to the remote
215216
session and load.
216217
hexagon_arch : str
217218
The hexagon arch to be used
218219
"""
219-
if isinstance(module, (relax.Executable, str)):
220+
if isinstance(module, (runtime.Executable, str)):
220221
return self._relax_vm_executable_executor(module, hexagon_arch=hexagon_arch)
221222

222223
raise TypeError(f"Unsupported executor type: {type(module)}")
223224

224-
def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule]):
225+
def _set_device_type(self, module: Union[str, pathlib.Path]):
225226
"""Set session device type(hexagon, cpu) based on target in module.
226227
227228
Parameters
@@ -244,40 +245,41 @@ def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactory
244245
self._requires_cpu_device = False
245246

246247
def _relax_vm_executable_executor(
247-
self, vm_exec: Union[relax.Executable, str], hexagon_arch: str
248+
self, executable: Union[runtime.Executable, str], hexagon_arch: str
248249
):
249250
"""Create a local TVM module which consumes a remote vm executable.
250251
251-
Paramters
252-
---------
252+
Parameters
253+
----------
253254
254-
vm_exec : relax.Executable
255-
The Relax VM Executable to upload to the remote and load. This will typically be the
256-
output of `relax.build` or the path to an already built and exported shared library
255+
executable : runtime.Executable
256+
The Executable to upload to the remote and load. This will typically be the
257+
output of `tvm.compile` or the path to an already built and exported shared library
257258
hexagon_arch : str
258259
The hexagon arch to be used
260+
259261
Returns
260262
-------
261263
TVMModule :
262264
TVM module object
263265
"""
264266
assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use"
265267

266-
if isinstance(vm_exec, relax.Executable):
268+
if isinstance(executable, runtime.Executable):
267269
temp_dir = utils.tempdir()
268270
path_exec = temp_dir.relpath("exec.so")
269271

270-
vm_exec.mod.export_library(
272+
executable.export_library(
271273
path_exec,
272274
fcompile=hexagon.create_aot_shared,
273275
hexagon_arch=hexagon_arch,
274276
)
275277

276278
path = self.upload(path_exec, "exec.so")
277-
elif isinstance(vm_exec, str):
278-
path_exec = vm_exec
279+
elif isinstance(executable, str):
280+
path_exec = executable
279281
else:
280-
raise TypeError(f"Unsupported executor type: {type(vm_exec)}")
282+
raise TypeError(f"Unsupported executor type: {type(executable)}")
281283

282284
path = self.upload(path_exec, "exec.so")
283285
return self._rpc.get_function("tvm.hexagon.load_module")(str(path))

python/tvm/driver/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
# pylint: disable=redefined-builtin
18+
1719
"""Namespace for driver APIs"""
18-
from .build_module import build
20+
from .build_module import build, compile

python/tvm/driver/build_module.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,95 @@
1717

1818
# pylint: disable=invalid-name
1919
"""The build utils in python."""
20-
from typing import Union, Optional
20+
import warnings
21+
from typing import Callable, Optional, Union
22+
2123
import tvm
22-
from tvm.tir import PrimFunc
2324
from tvm.ir.module import IRModule
25+
from tvm.runtime import Executable
2426
from tvm.target import Target
27+
from tvm.tir import PrimFunc
2528

2629

2730
def build(
2831
mod: Union[PrimFunc, IRModule],
2932
target: Optional[Union[str, Target]] = None,
30-
pipeline: Optional[Union[str, tvm.transform.Pass]] = "default_tir",
33+
pipeline: Optional[Union[str, tvm.transform.Pass]] = "default",
3134
):
35+
"""
36+
Build a function with a signature, generating code for devices
37+
coupled with target information.
38+
39+
This function is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.
40+
41+
Parameters
42+
----------
43+
mod : Union[PrimFunc, IRModule]
44+
The input to be built.
45+
target : Optional[Union[str, Target]]
46+
The target for compilation.
47+
pipeline : Optional[Union[str, tvm.transform.Pass]]
48+
The pipeline to use for compilation.
49+
50+
Returns
51+
-------
52+
tvm.runtime.Module
53+
A module combining both host and device code.
54+
"""
55+
warnings.warn(
56+
"build is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.",
57+
DeprecationWarning,
58+
)
3259
return tvm.tir.build(mod, target, pipeline)
60+
61+
62+
def _contains_relax(mod: Union[PrimFunc, IRModule]) -> bool:
63+
if isinstance(mod, PrimFunc):
64+
return False
65+
if isinstance(mod, IRModule):
66+
return any(isinstance(func, tvm.relax.Function) for _, func in mod.functions_items())
67+
68+
raise ValueError(f"Function input must be a PrimFunc or IRModule, but got {type(mod)}")
69+
70+
71+
def compile( # pylint: disable=redefined-builtin
72+
mod: Union[PrimFunc, IRModule],
73+
target: Optional[Target] = None,
74+
*,
75+
relax_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default",
76+
tir_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default",
77+
) -> Executable:
78+
"""
79+
Compile an IRModule to a runtime executable.
80+
81+
This function serves as a unified entry point for compiling both TIR and Relax modules.
82+
It automatically detects the module type and routes to the appropriate build function.
83+
84+
Parameters
85+
----------
86+
mod : Union[PrimFunc, IRModule]
87+
The input module to be compiled. Can be a PrimFunc or an IRModule containing
88+
TIR or Relax functions.
89+
target : Optional[Target]
90+
The target platform to compile for.
91+
relax_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]]
92+
The compilation pipeline to use for Relax functions.
93+
Only used if the module contains Relax functions.
94+
tir_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]]
95+
The compilation pipeline to use for TIR functions.
96+
97+
Returns
98+
-------
99+
Executable
100+
A runtime executable that can be loaded and executed.
101+
"""
102+
# TODO(tvm-team): combine two path into unified one
103+
if _contains_relax(mod):
104+
return tvm.relax.build(
105+
mod,
106+
target,
107+
relax_pipeline=relax_pipeline,
108+
tir_pipeline=tir_pipeline,
109+
)
110+
lib = tvm.tir.build(mod, target, pipeline=tir_pipeline)
111+
return Executable(lib)

python/tvm/meta_schedule/relax_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def compile_relax(
382382
target: Union[Target, str],
383383
params: Optional[Dict[str, NDArray]],
384384
enable_warning: bool = False,
385-
) -> "relax.Executable":
385+
) -> "relax.VMExecutable":
386386
"""Compile a relax program with a MetaSchedule database.
387387
388388
Parameters
@@ -401,8 +401,8 @@ def compile_relax(
401401
402402
Returns
403403
-------
404-
lib : relax.Executable
405-
The built runtime module or vm Executable for the given relax workload.
404+
lib : relax.VMExecutable
405+
The built runtime module or vm VMExecutable for the given relax workload.
406406
"""
407407
# pylint: disable=import-outside-toplevel
408408
from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase

python/tvm/meta_schedule/testing/custom_builder_runner.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,18 @@
1717
"""Customized builder and runner methods"""
1818
# pylint: disable=import-outside-toplevel
1919

20-
from typing import TYPE_CHECKING, Dict, Union, Callable
20+
from typing import Dict, Union, Callable
2121

22-
if TYPE_CHECKING:
23-
import numpy as np # type: ignore
24-
from tvm.ir import IRModule
25-
from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
26-
from tvm.runtime import Device, Module, NDArray
27-
from tvm.target import Target
22+
import numpy as np # type: ignore
23+
from tvm.meta_schedule.runner import RPCConfig
24+
from tvm.runtime import Module, Executable
2825

2926

3027
def run_module_via_rpc(
31-
rpc_config: "RPCConfig",
32-
lib: Union["Module", "Executable"],
28+
rpc_config: RPCConfig,
29+
lib: Union[Module, Executable],
3330
dev_type: str,
34-
args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
31+
args: Union[Dict[int, np.ndarray], Dict[str, np.ndarray]],
3532
continuation: Callable,
3633
):
3734
"""Execute a tvm.runtime.Module on RPC remote"""

0 commit comments

Comments
 (0)