Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def _find_parser_def(self):
def get_macro_def(self):
ast_module = self.source.as_ast()
for decl in ast_module.body:
if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__:
if isinstance(decl, doc.FunctionDef) and decl.name == self.func.__name__:
return decl
raise RuntimeError(f"cannot find macro definition for {self.__name__}")
raise RuntimeError(f"cannot find macro definition for {self.func.__name__}")

def __call__(self, *args, **kwargs):
param_binding = inspect.signature(self.func).bind(*args, **kwargs)
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ def macro(*args, hygienic: bool = True) -> _Callable:
def _decorator(func: _Callable) -> ScriptMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = RelaxMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj

def wrapper(*args, **kwargs):
return obj(*args, **kwargs)

return wrapper

if len(args) == 0:
return _decorator
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,11 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
def _decorator(func: Callable) -> TIRMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = TIRMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj

def wrapper(*args, **kwargs):
return obj(*args, **kwargs)

return wrapper

if len(args) == 0:
return _decorator
Expand Down
42 changes: 38 additions & 4 deletions tests/python/tvmscript/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def evaluate0():
def func1():
T.evaluate(0)

assert func1.hygienic

@T.prim_func(private=True)
def use1():
func1()
Expand All @@ -129,8 +127,6 @@ def use1():
def func2():
T.evaluate(0)

assert func2.hygienic

@T.prim_func(private=True)
def use2():
func2()
Expand Down Expand Up @@ -212,6 +208,44 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32"
tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)


def test_tir_macro_in_class():
class Object:
def __init__(self, x: T.Buffer):
self.local_x = T.alloc_buffer(x.shape, x.dtype)

@T.macro
def load(self, x: T.Buffer):
N, M = T.meta_var(self.local_x.shape)
for i, j in T.grid(N, M):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i, j])
self.local_x[vi, vj] = x[vi, vj]

@T.prim_func(private=True)
def func_w_macro(a: T.handle):
A = T.match_buffer(a, [128, 128])
o1 = T.meta_var(Object(A))
o1.load(A)
o2 = T.meta_var(Object(A))
o2.load(o1.local_x)

@T.prim_func(private=True)
def func_no_macro(a: T.handle):
A = T.match_buffer(a, [128, 128])
local_a = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i, j])
local_a[vi, vj] = A[vi, vj]
local_b = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i, j])
local_b[vi, vj] = local_a[vi, vj]

tvm.ir.assert_structural_equal(func_no_macro, func_w_macro)


def test_tir_starred_expression():
dims = (128, 128)

Expand Down
Loading