Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to"
return wrap_nested(_op.broadcast_to(x._expr, shape), name)


def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permute_dims") -> Tensor:
def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = None) -> Tensor:
"""Permutes the dimensions of an array.

Parameters
Expand All @@ -596,6 +596,13 @@ def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permu
result : Tensor
The transposed result.
"""
if name is None:
x_name = getattr(getattr(x, "_expr", None), "name_hint", None)
if x_name is not None and "linear" in x_name:
name = x_name.replace("linear", "matmul")
else:
name = "permute_dims"

return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)


Expand Down
51 changes: 25 additions & 26 deletions tests/python/relax/test_frontend_nn_packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
from tvm.script import relax as R


def main():
def _iter_binding_names(mod):
"""Helper function to compare the names of relax variables"""
for block in mod["forward"].body.blocks:
for binding in block.bindings:
yield binding.var.name_hint


def test_nn_export_to_relax():
class TestModule(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
Expand All @@ -35,39 +42,28 @@ def forward(self, x: nn.Tensor):
x2 = self.linear_2(x)
return x1 + x2

# pylint: disable=line-too-long
@I.ir_module
class ExpectedModule: # pylint: disable=too-few-public-methods
class ExpectedModule:
@R.function
def forward(
x: R.Tensor((1, 10), dtype="float32"),
packed_params: R.Tuple(
R.Tensor((20, 10), dtype="float32"), R.Tensor((20, 10), dtype="float32")
),
) -> R.Tensor((1, 20), dtype="float32"):
R.func_attr({"num_input": 1}) # type: ignore[attr-defined]
with R.dataflow(): # type: ignore[attr-defined]
linear_1_weight: R.Tensor((20, 10), dtype="float32") = packed_params[0] # type: ignore[valid-type]
linear_2_weight: R.Tensor((20, 10), dtype="float32") = packed_params[1] # type: ignore[valid-type]
permute_dims: R.Tensor((10, 20), dtype="float32") = R.permute_dims( # type: ignore[attr-defined,valid-type]
linear_1_weight, axes=None
)
matmul: R.Tensor((1, 20), dtype="float32") = R.matmul( # type: ignore[attr-defined,valid-type]
x, permute_dims, out_dtype="void"
)
permute_dims1: R.Tensor((10, 20), dtype="float32") = R.permute_dims( # type: ignore[attr-defined,valid-type]
linear_2_weight, axes=None
)
matmul1: R.Tensor((1, 20), dtype="float32") = R.matmul( # type: ignore[attr-defined,valid-type]
x, permute_dims1, out_dtype="void"
)
add: R.Tensor((1, 20), dtype="float32") = R.add(matmul, matmul1) # type: ignore[attr-defined,valid-type]
gv: R.Tensor((1, 20), dtype="float32") = add # type: ignore[attr-defined,valid-type]
R.output(gv) # type: ignore[attr-defined,valid-type]
):
R.func_attr({"num_input": 1})
with R.dataflow():
linear_1_weight = packed_params[0]
linear_2_weight = packed_params[1]
matmul_1_weight = R.permute_dims(linear_1_weight)
matmul = R.matmul(x, matmul_1_weight)
matmul_2_weight = R.permute_dims(linear_2_weight)
matmul1 = R.matmul(x, matmul_2_weight)
add = R.add(matmul, matmul1)
gv = add
R.output(gv)
return gv

# pylint: enable=line-too-long

model = TestModule(10, 20)
mod, _ = model.export_tvm(
spec={
Expand All @@ -82,6 +78,9 @@ def forward(
)
tvm.ir.assert_structural_equal(mod, ExpectedModule)

for name, expected_name in zip(_iter_binding_names(mod), _iter_binding_names(ExpectedModule)):
assert name == expected_name


if __name__ == "__main__":
main()
tvm.testing.main()