Skip to content

Commit 9b1846d

Browse files
committed
[TVMScript][UX] Introduce decorator for deprecation
This PR introduces a decorator `tvm.ir.base.deprecated`, which emits a deprecation warning if an outdated API is used, but preserves backward compatibility by still allowing the API to be used. For example, currently the preferred way of TIR buffer declaration in function signature is: ```python def example( A: T.Buffer(...), # legacy behavior is `T.Buffer[...]` ): ... ``` With this decorator, if a user writes `T.Buffer[...]`, the parser will still function properly, but emits a warning that guides the user to adopt `T.Buffer(...)` if possible. While there is no breaking change at all in this PR, we believe this is useful to help users upgrade before any breaking change eventually takes place.
1 parent 45a92df commit 9b1846d

File tree

176 files changed

+3563
-3524
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

176 files changed

+3563
-3524
lines changed

apps/pt_tvmdsoop/tests/test_as_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
5252
@tvm.script.ir_module
5353
class ModuleGPU:
5454
@T.prim_func
55-
def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
55+
def main(A: T.Buffer(8, "float32"), B: T.Buffer(8, "float32")) -> None:
5656
T.func_attr({"global_symbol": "main", "tir.noalias": True})
5757
for i_0 in T.thread_binding(2, thread="blockIdx.x"):
5858
for i_2 in T.thread_binding(2, thread="threadIdx.x"):

apps/pt_tvmdsoop/tests/test_boolean_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def test_tensor_boolean_operation():
8181
@as_torch
8282
@T.prim_func
8383
def negate_tvmscript(
84-
X: T.Buffer[(8, 8), "bool"],
85-
Y: T.Buffer[(8, 8), "float32"],
86-
Z: T.Buffer[(8, 8), "bool"],
87-
U: T.Buffer[(8, 8), "float32"],
84+
X: T.Buffer((8, 8), "bool"),
85+
Y: T.Buffer((8, 8), "float32"),
86+
Z: T.Buffer((8, 8), "bool"),
87+
U: T.Buffer((8, 8), "float32"),
8888
) -> None:
8989
for i, j in T.grid(8, 8):
9090
with T.block():

include/tvm/script/printer/doc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ class AssignDocNode : public StmtDocNode {
774774
/*!
775775
* \brief The right hand side of the assignment.
776776
*
777-
* If null, this doc represents declaration, e.g. `A: T.Buffer[(1,2)]`
777+
* If null, this doc represents declaration, e.g. `A: T.Buffer((1,2))`
778778
* */
779779
Optional<ExprDoc> rhs;
780780
/*! \brief The type annotation of this assignment. */

include/tvm/tir/transform.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
576576
*
577577
* \code{.py}
578578
* @T.prim_func
579-
* def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
579+
* def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
580580
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
581581
* for i in T.serial(0, 16,
582582
* annotations={"software_pipeline_stage": [0, 1],
@@ -601,7 +601,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
601601
*
602602
* \code{.py}
603603
* @T.prim_func
604-
* def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
604+
* def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
605605
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
606606
* with T.block():
607607
* T.reads([A[tx, 0:16]])

python/tvm/ir/base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,34 @@ def structural_hash(node, map_free_vars=False):
282282
structrual_equal
283283
"""
284284
return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member
285+
286+
287+
def deprecated(
288+
method_name: str,
289+
new_method_name: str,
290+
):
291+
"""A decorator to indicate that a method is deprecated
292+
293+
Parameters
294+
----------
295+
method_name : str
296+
The name of the method to deprecate
297+
new_method_name : str
298+
The name of the new method to use instead
299+
"""
300+
import functools # pylint: disable=import-outside-toplevel
301+
import warnings # pylint: disable=import-outside-toplevel
302+
303+
def _deprecate(func):
304+
@functools.wraps(func)
305+
def _wrapper(*args, **kwargs):
306+
warnings.warn(
307+
f"{method_name} is deprecated, use {new_method_name} instead",
308+
DeprecationWarning,
309+
stacklevel=2,
310+
)
311+
return func(*args, **kwargs)
312+
313+
return _wrapper
314+
315+
return _deprecate
Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
# Licensed to the Apache Software Foundation (ASF) under one
2-
# or more contributor license agreements. See the NOTICE file
3-
# distributed with this work for additional information
4-
# regarding copyright ownership. The ASF licenses this file
5-
# to you under the Apache License, Version 2.0 (the
6-
# "License"); you may not use this file except in compliance
7-
# with the License. You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an
13-
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14-
# KIND, either express or implied. See the License for the
15-
# specific language governing permissions and limitations
16-
# under the License.
17-
18-
"""Module container of STM32 code generator."""
19-
20-
from .emitter import CodeEmitter, get_input_tensor_name, get_output_tensor_name
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Module container of STM32 code generator."""
19+
20+
from .emitter import CodeEmitter, get_input_tensor_name, get_output_tensor_name

python/tvm/parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,36 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""The legacy TVM parser """
19+
from .ir.base import deprecated
20+
1921
# pylint: disable=import-outside-toplevel
2022

2123

24+
@deprecated("tvm.parser.parse", "tvm.relay.parse")
2225
def parse(*args, **kwargs):
2326
"""Deprecated, use `tvm.relay.parse` instead"""
2427
from tvm.relay import parse as _impl
2528

2629
return _impl(*args, **kwargs)
2730

2831

32+
@deprecated("tvm.parser.parse_expr", "tvm.relay.parse_expr")
2933
def parse_expr(*args, **kwargs):
3034
"""Deprecated, use `tvm.relay.parse_expr` instead"""
3135
from tvm.relay import parse_expr as _impl
3236

3337
return _impl(*args, **kwargs)
3438

3539

40+
@deprecated("tvm.parser.fromtext", "tvm.relay.fromtext")
3641
def fromtext(*args, **kwargs):
3742
"""Deprecated, use `tvm.relay.fromtext` instead"""
3843
from tvm.relay import fromtext as _impl
3944

4045
return _impl(*args, **kwargs)
4146

4247

48+
@deprecated("tvm.parser.SpanCheck", "tvm.relay.SpanCheck")
4349
def SpanCheck(*args, **kwargs):
4450
"""Deprecated, use `tvm.relay.SpanCheck` instead"""
4551
from tvm.relay import SpanCheck as _impl

python/tvm/script/parser/core/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""TVM Script Parser utils"""
18-
1918
import inspect
2019
from types import FrameType
2120
from typing import Any, Callable, Dict, List

python/tvm/script/parser/tir/entry.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import inspect
1919
from typing import Callable, Union
2020

21+
from tvm.ir.base import deprecated
2122
from tvm.tir import Buffer, PrimFunc
2223

2324
from ...ir_builder.tir import buffer_decl, ptr
@@ -49,7 +50,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
4950

5051
class BufferProxy:
5152
"""Buffer proxy class for constructing tir buffer.
52-
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer[].
53+
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer().
5354
"""
5455

5556
def __call__(
@@ -78,6 +79,7 @@ def __call__(
7879
axis_separators=axis_separators,
7980
)
8081

82+
@deprecated("T.Buffer(...)", "T.Buffer(...)")
8183
def __getitem__(self, keys) -> Buffer:
8284
if not isinstance(keys, tuple):
8385
return self(keys)
@@ -88,14 +90,15 @@ def __getitem__(self, keys) -> Buffer:
8890

8991
class PtrProxy:
9092
"""Ptr proxy class for constructing tir pointer.
91-
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[].
93+
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
9294
"""
9395

9496
def __call__(self, dtype, storage_scope="global"):
9597
if callable(dtype):
9698
dtype = dtype().dtype
9799
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore
98100

101+
@deprecated("T.Ptr(...)", "T.Ptr(...)")
99102
def __getitem__(self, keys):
100103
if not isinstance(keys, tuple):
101104
return self(keys)

python/tvm/testing/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,13 +1932,13 @@ class object that inherits from `Exception`.
19321932
class TestRemoveIf(tvm.testing.CompareBeforeAfter):
19331933
transform = tvm.tir.transform.Simplify()
19341934
1935-
def before(A: T.Buffer[1, "int32"]):
1935+
def before(A: T.Buffer(1, "int32")):
19361936
if True:
19371937
A[0] = 42
19381938
else:
19391939
A[0] = 5
19401940
1941-
def expected(A: T.Buffer[1, "int32"]):
1941+
def expected(A: T.Buffer(1, "int32")):
19421942
A[0] = 42
19431943
19441944
"""

0 commit comments

Comments
 (0)