Skip to content

Commit 2cab720

Browse files
committed
Add entry point
1 parent 247c54b commit 2cab720

File tree

11 files changed

+327
-3
lines changed

11 files changed

+327
-3
lines changed

include/tvm/script/printer.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_SCRIPT_PRINTER_H_
20+
#define TVM_SCRIPT_PRINTER_H_
21+
22+
#include <tvm/node/node.h>
23+
#include <tvm/node/object_path.h>
24+
25+
namespace tvm {
26+
namespace script {
27+
28+
/*!
29+
* \brief Print IR graph as TVMScript code
30+
*
31+
* \param root_node The root node to print.
32+
* \param ir_name The dispatch token of the target IR, e.g., "tir", "relax".
33+
* \param ir_prefix The symbol name for TVMScript IR namespaces. For example, {"tir": "T"}.
34+
* \param indent_spaces Number of spaces used for indentation
35+
* \param print_line_numbers Whether to print line numbers
36+
* \param num_context_lines Number of context lines to print around the underlined text
37+
* \param path_to_underline Object path to be underlined
38+
*
39+
* \return the TVMScript code as string.
40+
*/
41+
String AsScript( //
42+
const ObjectRef& root_node, //
43+
String ir_name, //
44+
Map<String, String> ir_prefix, //
45+
int indent_spaces = 4, //
46+
bool print_line_numbers = false, //
47+
int num_context_lines = -1, //
48+
Optional<ObjectPath> path_to_underline = NullOpt //
49+
);
50+
51+
} // namespace script
52+
} // namespace tvm
53+
54+
#endif // TVM_SCRIPT_PRINTER_H_

include/tvm/script/printer/doc.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ class ExprDoc : public Doc {
125125
ExprDoc() = default;
126126

127127
public:
128+
/*!
129+
* \brief Create a doc representing index access on the current ExprDoc
130+
* \param indices The indices to access.
131+
*/
132+
ExprDoc operator[](Array<Doc> indices) const;
133+
128134
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
129135
};
130136

include/tvm/script/printer/ir_docsifier.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,47 @@ class IRDocsifier : public ObjectRef {
182182
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode);
183183
};
184184

185+
/*!
186+
* \brief A wrapper object to provide injection point for printer of each IR.
187+
*
188+
* For any IR node to be transformed by IRDocsifier, it will be wrapped by RootNodeContainer
189+
* and be dispatched to the corresponding function first. This provides an injection point for
190+
* each IR's printer implemention to add specialized logic, for example, pushing a special
191+
* Frame to the IRDocsifier before doing any IR->Doc transformation.
192+
*
193+
* \code
194+
* TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
195+
* .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
196+
* const ObjectRef& root_node = obj.Get()->root_node;
197+
* // For example, relax printer can create a Frame specialized to Relax here
198+
* RelaxGeneralFrame frame;
199+
* auto ctx = p->WithFrame(frame);
200+
* // More specialized logic for your IR.
201+
* return p->AsDoc<Doc>(MakeTraced(root_node));
202+
* });
203+
* \endcode
204+
*/
205+
class RootNodeContainerNode : public Object {
206+
public:
207+
/*! \brief The root node to print. */
208+
ObjectRef root_node;
209+
210+
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("root_node", &root_node); }
211+
212+
static constexpr const char* _type_key = "script.printer.RootNodeContainer";
213+
TVM_DECLARE_FINAL_OBJECT_INFO(RootNodeContainerNode, Object);
214+
};
215+
216+
class RootNodeContainer : public ObjectRef {
217+
public:
218+
/*!
219+
* \brief Constructor of RootNodeContainer.
220+
* \param root_node The root node to print.
221+
* */
222+
explicit RootNodeContainer(ObjectRef root_node);
223+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootNodeContainer, ObjectRef, RootNodeContainerNode);
224+
};
225+
185226
} // namespace printer
186227
} // namespace script
187228
} // namespace tvm

python/tvm/script/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from . import tir
2020

2121
from .parser import ir_module, from_source
22+
from .as_script import as_script

python/tvm/script/as_script.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
This file contains the entry point of TVMScript Unified Printer.
19+
"""
20+
21+
from typing import Dict, Optional
22+
23+
from tvm.runtime.object_path import ObjectPath
24+
25+
from . import _ffi_api
26+
27+
28+
def as_script(
29+
root_node,
30+
ir_name: str,
31+
ir_prefix: Dict[str, str],
32+
indent_spaces: int = 4,
33+
print_line_numbers: bool = False,
34+
num_context_lines: int = -1,
35+
path_to_underline: Optional[ObjectPath] = None,
36+
) -> str:
37+
return _ffi_api.AsScript(
38+
root_node,
39+
ir_name,
40+
ir_prefix,
41+
indent_spaces,
42+
print_line_numbers,
43+
num_context_lines,
44+
path_to_underline,
45+
)

python/tvm/script/printer/ir_docsifier.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,21 @@ def _ensure_cleanup_function_registered():
5959
_CLEANUP_REGISTERED = True
6060

6161

62+
@register_object("script.printer.RootNodeContainer")
63+
class RootNodeContainer(Object):
64+
"""
65+
A wrapper object to provide injection point for printer of each IR.
66+
67+
This class shouldn't be used directly. `IRDocsifier.set_root_dispatch`
68+
should be used instead.
69+
"""
70+
71+
root_node: Object
72+
73+
def __init__(self, root_node: Object):
74+
self.__init_handle_by_constructor__(_ffi_api.RootNodeContainer, root_node) # type: ignore # pylint: disable=no-member
75+
76+
6277
@register_object("script.printer.IRDocsifier")
6378
class IRDocsifier(Object):
6479
"""
@@ -91,7 +106,7 @@ def __init__(self, ir_prefix: Dict[str, str]):
91106
def set_dispatch(
92107
cls,
93108
node_type: Type[_TObject],
94-
dispatch_function: Callable[[_TObject, "IRDocsifier"], Doc],
109+
dispatch_function: Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc],
95110
dispatch_token: str = "",
96111
) -> None:
97112
"""
@@ -101,7 +116,7 @@ def set_dispatch(
101116
----------
102117
node_type : Type[_TObject]
103118
The type of object to dispatch on.
104-
dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
119+
dispatch_function : Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc]
105120
The dispatch function. It's called to transform IR node object to Doc.
106121
dispatch_token : str
107122
Function will only be called when this dispatch_token is the same as the one
@@ -119,6 +134,38 @@ def set_dispatch(
119134
)
120135
_REGISTERED_TYPES.add((dispatch_token, type_index))
121136

137+
@classmethod
138+
def set_root_dispatch(
139+
cls, dispatch_token: str, root_dispatch_function: Callable[[Object, "IRDocsifier"], Doc]
140+
) -> None:
141+
"""
142+
Set the root dispatch function for an IR.
143+
144+
The root dispatch function will be called with the root node of an IR graph
145+
that's being transformed to Doc. This provides an injection point for
146+
each IR's printer implemention to add specialized logic, for example,
147+
pushing a special Frame to the IRDocsifier before doing actual IR->Doc
148+
transformation.
149+
150+
The simplest root dispatch function is
151+
```
152+
def f(obj, ir_docsifier)
153+
return ir_docsifier.as_doc(obj, ObjectPath.root())
154+
```
155+
156+
Parameters
157+
----------
158+
root_dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
159+
The root dispatch function. It's called with the root node to be printed.
160+
dispatch_token : str
161+
The dispatch token of the IR that root_dispatch_funnction applies to.
162+
"""
163+
164+
def dispatch_function(obj: RootNodeContainer, _, ir_docsifier):
165+
return root_dispatch_function(obj.root_node, ir_docsifier)
166+
167+
cls.set_dispatch(RootNodeContainer, dispatch_function, dispatch_token)
168+
122169
def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc:
123170
"""
124171
Transform the input object into Doc.

src/script/printer.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <tvm/runtime/registry.h>
21+
#include <tvm/script/printer.h>
22+
#include <tvm/script/printer/doc.h>
23+
#include <tvm/script/printer/doc_printer.h>
24+
#include <tvm/script/printer/frame.h>
25+
#include <tvm/script/printer/ir_docsifier.h>
26+
27+
namespace tvm {
28+
namespace script {
29+
30+
using namespace printer;
31+
32+
String AsScript( //
33+
const ObjectRef& root_node, //
34+
String ir_name, //
35+
Map<String, String> ir_prefix, //
36+
int indent_spaces, //
37+
bool print_line_numbers, //
38+
int num_context_lines, //
39+
Optional<ObjectPath> path_to_underline //
40+
) {
41+
IRDocsifier ir_docsifier(ir_prefix);
42+
43+
auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name);
44+
45+
Doc doc = ir_docsifier->AsDoc<Doc>(MakeTraced(RootNodeContainer(root_node)));
46+
47+
return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
48+
path_to_underline);
49+
}
50+
51+
TVM_REGISTER_GLOBAL("script.AsScript").set_body_typed(&AsScript);
52+
53+
} // namespace script
54+
} // namespace tvm

src/script/printer/doc.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_
4040
return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values);
4141
}
4242

43+
ExprDoc ExprDoc::operator[](Array<Doc> indices) const { return (*get())[indices]; }
44+
4345
StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
4446
ObjectPtr<StmtBlockDocNode> n = make_object<StmtBlockDocNode>();
4547
n->stmts = stmts;

src/script/printer/ir_docsifier.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* under the License.
1818
*/
1919
#include <tvm/runtime/container/base.h>
20+
#include <tvm/runtime/logging.h>
2021
#include <tvm/runtime/registry.h>
2122
#include <tvm/script/printer/ir_docsifier.h>
2223
#include <tvm/script/printer/traced_object.h>
@@ -42,6 +43,31 @@ IRDocsifier::FType& IRDocsifier::vtable() {
4243
return inst;
4344
}
4445

46+
RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
47+
auto n = make_object<RootNodeContainerNode>();
48+
n->root_node = std::move(root_node);
49+
data_ = std::move(n);
50+
}
51+
52+
// Add a default dispatch for the RootNodeContainer to throw error.
53+
// To add implementation for a new IR, RootNodeContainer needs to be
54+
// registered under the dispatch token of that IR, like:
55+
// \code
56+
// TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
57+
// .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
58+
// const ObjectRef& root_node = obj.Get()->root_node;
59+
// \\ More specialized logic for your IR.
60+
// return p->AsDoc<Doc>(MakeTraced(root_node));
61+
// });
62+
// \endcode
63+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
64+
.set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
65+
String top_dispatch_token = p->dispatch_tokens.back();
66+
ICHECK_NE(top_dispatch_token, "");
67+
ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
68+
throw;
69+
});
70+
4571
TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
4672
TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map<String, String> ir_prefix) {
4773
return IRDocsifier(ir_prefix);
@@ -71,6 +97,12 @@ TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch")
7197
.set_body_typed([](String token, uint64_t type_index) {
7298
IRDocsifier::vtable().remove_dispatch(token, type_index);
7399
});
100+
101+
TVM_REGISTER_NODE_TYPE(RootNodeContainerNode);
102+
TVM_REGISTER_GLOBAL("script.printer.RootNodeContainer").set_body_typed([](ObjectRef root_node) {
103+
return RootNodeContainer(root_node);
104+
});
105+
74106
} // namespace printer
75107
} // namespace script
76108
} // namespace tvm
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import pytest
18+
19+
from tvm.error import TVMError
20+
from tvm.script import as_script
21+
from tvm.tir import FloatImm
22+
23+
24+
def test_as_script_unknown_ir():
25+
ir_node = FloatImm("float32", 1.0)
26+
27+
with pytest.raises(TVMError) as e:
28+
as_script(ir_node, "test_xyz", {})
29+
30+
assert "test_xyz" in str(e.value)

0 commit comments

Comments
 (0)