Skip to content

Commit 3cc384d

Browse files
committed
[SLM] Provide consistent output features in nn.Linear
Resolve a breakage introduced in apache#16757. Prior to apache#16757, distinct TIR variables were unified if they had the same name. This commit avoids using distinct TIR variables to represent the same user input.
1 parent 80bcf4c commit 3cc384d

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

python/tvm/relax/frontend/nn/modules.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def __init__(
106106
out_dtype: Optional[str] = None,
107107
):
108108
super().__init__()
109+
110+
if isinstance(out_features, str):
111+
out_features = tir.Var(out_features, "int64")
112+
109113
self.in_features = in_features
110114
self.out_features = out_features
111115
self.out_dtype = out_dtype

tests/python/relax/test_frontend_nn_exporter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,5 +439,35 @@ def transform_params(
439439
assert_structural_equal(lifted_mod, ExpectedAfterLift)
440440

441441

442+
def test_linear_dynamic_shape():
    """The weight and bias of nn.Linear must share the same out_features.

    When ``out_features`` is given as a string (a dynamic dimension), the
    exported weight and bias must refer to a single TIR variable ``n``
    rather than two distinct variables that merely share a name.
    """

    # Expected module produced by exporting the nn.Linear below.
    @R.function
    def forward(
        x: R.Tensor((1, 4), dtype="float32"),
        _io: R.Object,
        weight: R.Tensor(("n", 4), dtype="float32"),
        bias: R.Tensor(("n",), dtype="float32"),
    ) -> R.Tuple(R.Tensor((1, "n"), dtype="float32"), R.Tuple(R.Object)):
        n = T.int64()
        R.func_attr({"num_input": 2})
        with R.dataflow():
            # x @ weight.T + bias, with the dynamic output dimension n
            # appearing consistently in every intermediate shape.
            permute_dims: R.Tensor((4, n), dtype="float32") = R.permute_dims(weight, axes=None)
            matmul: R.Tensor((1, n), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, n), dtype="float32") = R.add(matmul, bias)
            gv1: R.Tuple(R.Tensor((1, n), dtype="float32"), R.Tuple(R.Object)) = add, (_io,)
            R.output(gv1)
        return gv1

    # out_features given as a string must become a single tir.Var shared
    # by both the weight and the bias parameters.
    mod = nn.modules.Linear(in_features=4, out_features="n", bias=True)
    tvm_mod, _ = mod.export_tvm(
        spec={"forward": {"x": nn.spec.Tensor((1, 4), "float32")}}, debug=True
    )
    assert_structural_equal(tvm_mod["forward"], forward, True)
470+
471+
442472
if __name__ == "__main__":
443473
tvm.testing.main()

0 commit comments

Comments (0)