Skip to content

Commit 8226e16

Browse files
jinhongyiiLunderberg
authored andcommitted
[TVMScript] Expose IRModule::attrs as I.module_attrs
This is an upstreaming of the non-relax portions of #14132, including a unit test specically to validate `I.module_attrs`.
1 parent bfeafa2 commit 8226e16

File tree

14 files changed

+100
-12
lines changed

14 files changed

+100
-12
lines changed

include/tvm/script/ir_builder/base.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
237237
* \sa tvm::support::With
238238
*/
239239
static IRBuilder Current();
240+
/*! \brief See if the current thread-local scope has an IRBuilder. */
241+
static bool IsInScope();
240242
/*!
241243
* \brief Give a string name to the `obj`
242244
* \tparam TObjectRef The type of the object to name.

include/tvm/script/ir_builder/ir/frame.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode {
4545
* \note Only defined functions are in the map, while declared functions are not included.
4646
*/
4747
Map<GlobalVar, BaseFunc> functions;
48+
/*! \brief IRModule's attributes. */
49+
Map<String, ObjectRef> attrs;
4850

4951
void VisitAttrs(tvm::AttrVisitor* v) {
5052
IRBuilderFrameNode::VisitAttrs(v);
5153
v->Visit("global_vars", &global_var_map);
5254
v->Visit("functions", &functions);
55+
v->Visit("attrs", &attrs);
5356
}
5457

5558
static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";

python/tvm/ir/module.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class IRModule(Node, Scriptable):
3737
Map of global var to BaseFunc
3838
"""
3939

40-
def __init__(self, functions=None, type_definitions=None):
40+
def __init__(self, functions=None, type_definitions=None, attrs=None):
4141
if functions is None:
4242
functions = {}
4343
elif isinstance(functions, dict):
@@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None):
6060
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
6161
mapped_type_defs[k] = v
6262
type_definitions = mapped_type_defs
63-
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
63+
64+
attrs = None if not attrs else attrs
65+
if attrs is not None:
66+
attrs = ast.literal_eval(str(attrs))
67+
attrs = tvm.ir.make_node("DictAttrs", **attrs)
68+
self.__init_handle_by_constructor__(
69+
_ffi_api.IRModule,
70+
functions,
71+
type_definitions,
72+
attrs,
73+
)
6474

6575
def __setitem__(self, var, val):
6676
"""Add a mapping to the module.

python/tvm/script/ir_builder/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,17 @@ def current() -> "IRBuilder":
138138
"""
139139
return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member
140140

141+
@staticmethod
142+
def is_in_scope() -> bool:
143+
"""See if the current thread-local scope has an IRBuilder.
144+
145+
Returns
146+
-------
147+
bool
148+
Whether the current thread-local scope has an IRBuilder
149+
"""
150+
return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member
151+
141152
def get(self) -> _Object:
142153
"""Get the constructed IR."""
143154
return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member

python/tvm/script/ir_builder/ir/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,9 @@
1616
# under the License.
1717
"""Package tvm.script.ir_builder.ir"""
1818
from .frame import IRModuleFrame
19-
from .ir import decl_function, def_function, ir_module
19+
from .ir import (
20+
decl_function,
21+
def_function,
22+
ir_module,
23+
module_attrs,
24+
)

python/tvm/script/ir_builder/ir/ir.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
# under the License.
1717
"""Package tvm.script.ir_builder.ir.ir"""
1818

19+
from typing import Dict, List
20+
21+
from tvm.runtime import Object as tvm_Object
22+
1923
from tvm.ir import BaseFunc, GlobalVar
2024

2125
from . import _ffi_api
@@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None:
6771
The given function implementation
6872
"""
6973
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member
74+
75+
76+
def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
77+
"""Specify the attrs of the ir_module frame.
78+
Parameters
79+
----------
80+
attrs: Dict[str, Object]
81+
The module attrs.
82+
"""
83+
return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member

python/tvm/script/parser/ir/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""The ir module parser"""
18-
18+
from ...ir_builder.ir import * # pylint: disable=redefined-builtin
1919
from . import parser as _parser
2020
from .entry import ir_module
2121

22-
__all__ = ["ir_module"]
22+
__all__ = ["ir_module", "module_attrs"]

python/tvm/script/parser/ir/parser.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
3535

3636
with self.var_table.with_frame():
3737
with I.ir_module():
38+
with self.with_dispatch_token("ir"):
39+
for stmt in node.body:
40+
if not isinstance(stmt, doc.FunctionDef):
41+
self.visit(stmt)
3842
for stmt in node.body:
3943
if isinstance(stmt, doc.FunctionDef):
4044
self.visit_tvm_declare_function(stmt)
4145
with self.with_dispatch_token("ir"):
42-
self.visit_body(node.body)
46+
for stmt in node.body:
47+
if isinstance(stmt, doc.FunctionDef):
48+
self.visit(stmt)
4349

4450

4551
@dispatch.register(token="ir", type_name="Assign")
@@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
5763

5864

5965
@dispatch.register(token="ir", type_name="Expr")
60-
def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
66+
def _visit_expr(self: Parser, node: doc.Expr) -> None:
6167
"""The expression visiting method for ir module.
6268
6369
Parameters
@@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
6874
node : doc.ClassDef
6975
The doc AST expression node.
7076
"""
77+
self.eval_expr(node.value)
7178

7279

7380
@dispatch.register(token="default", type_name="Assign")

src/ir/module.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
382382
TVM_REGISTER_NODE_TYPE(IRModuleNode);
383383

384384
TVM_REGISTER_GLOBAL("ir.IRModule")
385-
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
386-
tvm::Map<GlobalTypeVar, TypeData> types) {
387-
return IRModule(funcs, types, {});
388-
});
385+
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
386+
tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); });
389387

390388
TVM_REGISTER_GLOBAL("ir.Module_Add")
391389
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {

src/script/ir_builder/base.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() {
7777
return stack->back();
7878
}
7979

80+
bool IRBuilder::IsInScope() {
81+
std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
82+
return !stack->empty();
83+
}
84+
8085
namespace details {
8186

8287
Namer::FType& Namer::vtable() {
@@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return
106111
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
107112
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
108113
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
114+
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope);
109115
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet")
110116
.set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
111117
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);

0 commit comments

Comments
 (0)