Skip to content

Commit f215a41

Browse files
authored
[Unity][NN] Use Linear name for nn.op.permute_dims (#16303)
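`relax::op::linear` is implemented as `permute_dims` followed by `matmul`. When `permute_dims` is called without an explicit name and its input's name contains "linear" (e.g. `linear_1_weight`), the result is now named after that input (e.g. `matmul_1_weight`) instead of the generic `permute_dims`, which makes the exported Relax module easier to read.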
1 parent fe5f616 commit f215a41

File tree

2 files changed

+33
-27
lines changed

2 files changed

+33
-27
lines changed

python/tvm/relax/frontend/nn/op.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to"
577577
return wrap_nested(_op.broadcast_to(x._expr, shape), name)
578578

579579

580-
def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permute_dims") -> Tensor:
580+
def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = None) -> Tensor:
581581
"""Permutes the dimensions of an array.
582582
583583
Parameters
@@ -596,6 +596,13 @@ def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permu
596596
result : Tensor
597597
The transposed result.
598598
"""
599+
if name is None:
600+
x_name = getattr(getattr(x, "_expr", None), "name_hint", None)
601+
if x_name is not None and "linear" in x_name:
602+
name = x_name.replace("linear", "matmul")
603+
else:
604+
name = "permute_dims"
605+
599606
return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)
600607

601608

tests/python/relax/test_frontend_nn_packing.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@
2121
from tvm.script import relax as R
2222

2323

24-
def main():
24+
def _iter_binding_names(mod):
25+
"""Helper function to compare the names of relax variables"""
26+
for block in mod["forward"].body.blocks:
27+
for binding in block.bindings:
28+
yield binding.var.name_hint
29+
30+
31+
def test_nn_export_to_relax():
2532
class TestModule(nn.Module):
2633
def __init__(self, in_features: int, out_features: int):
2734
super().__init__()
@@ -35,39 +42,28 @@ def forward(self, x: nn.Tensor):
3542
x2 = self.linear_2(x)
3643
return x1 + x2
3744

38-
# pylint: disable=line-too-long
3945
@I.ir_module
40-
class ExpectedModule: # pylint: disable=too-few-public-methods
46+
class ExpectedModule:
4147
@R.function
4248
def forward(
4349
x: R.Tensor((1, 10), dtype="float32"),
4450
packed_params: R.Tuple(
4551
R.Tensor((20, 10), dtype="float32"), R.Tensor((20, 10), dtype="float32")
4652
),
47-
) -> R.Tensor((1, 20), dtype="float32"):
48-
R.func_attr({"num_input": 1}) # type: ignore[attr-defined]
49-
with R.dataflow(): # type: ignore[attr-defined]
50-
linear_1_weight: R.Tensor((20, 10), dtype="float32") = packed_params[0] # type: ignore[valid-type]
51-
linear_2_weight: R.Tensor((20, 10), dtype="float32") = packed_params[1] # type: ignore[valid-type]
52-
permute_dims: R.Tensor((10, 20), dtype="float32") = R.permute_dims( # type: ignore[attr-defined,valid-type]
53-
linear_1_weight, axes=None
54-
)
55-
matmul: R.Tensor((1, 20), dtype="float32") = R.matmul( # type: ignore[attr-defined,valid-type]
56-
x, permute_dims, out_dtype="void"
57-
)
58-
permute_dims1: R.Tensor((10, 20), dtype="float32") = R.permute_dims( # type: ignore[attr-defined,valid-type]
59-
linear_2_weight, axes=None
60-
)
61-
matmul1: R.Tensor((1, 20), dtype="float32") = R.matmul( # type: ignore[attr-defined,valid-type]
62-
x, permute_dims1, out_dtype="void"
63-
)
64-
add: R.Tensor((1, 20), dtype="float32") = R.add(matmul, matmul1) # type: ignore[attr-defined,valid-type]
65-
gv: R.Tensor((1, 20), dtype="float32") = add # type: ignore[attr-defined,valid-type]
66-
R.output(gv) # type: ignore[attr-defined,valid-type]
53+
):
54+
R.func_attr({"num_input": 1})
55+
with R.dataflow():
56+
linear_1_weight = packed_params[0]
57+
linear_2_weight = packed_params[1]
58+
matmul_1_weight = R.permute_dims(linear_1_weight)
59+
matmul = R.matmul(x, matmul_1_weight)
60+
matmul_2_weight = R.permute_dims(linear_2_weight)
61+
matmul1 = R.matmul(x, matmul_2_weight)
62+
add = R.add(matmul, matmul1)
63+
gv = add
64+
R.output(gv)
6765
return gv
6866

69-
# pylint: enable=line-too-long
70-
7167
model = TestModule(10, 20)
7268
mod, _ = model.export_tvm(
7369
spec={
@@ -82,6 +78,9 @@ def forward(
8278
)
8379
tvm.ir.assert_structural_equal(mod, ExpectedModule)
8480

81+
for name, expected_name in zip(_iter_binding_names(mod), _iter_binding_names(ExpectedModule)):
82+
assert name == expected_name
83+
8584

8685
if __name__ == "__main__":
87-
main()
86+
tvm.testing.main()

0 commit comments

Comments
 (0)