diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index e46553203fa4..45428692b830 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2265,7 +2265,7 @@ def sort(x: Tensor, axis: int = -1, descending: bool = False, name="sort"): out : Tensor The sorted tensor. """ - return wrap_nested(_op.sort(x, axis, descending), name=name) + return wrap_nested(_op.sort(x._expr, axis, descending), name=name) def argsort( @@ -2296,7 +2296,7 @@ def argsort( out : Tensor The indices of the sorted tensor. """ - return wrap_nested(_op.argsort(data, axis, descending, dtype), name=name) + return wrap_nested(_op.argsort(data._expr, axis, descending, dtype), name=name) def topk( @@ -2344,7 +2344,7 @@ def topk( out : Tensor or Tuple[Tensor, Tensor] The computed result. """ - return wrap_nested(_op.topk(data, k, axis, ret_type, largest, dtype), name=name) + return wrap_nested(_op.topk(data._expr, k, axis, ret_type, largest, dtype), name=name) def multinomial_from_uniform( diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 7d78e47c945b..8bf52d7918e5 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -1188,5 +1188,34 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), d ) +def test_sort_argsort_topk(): + class Model(Module): + def foo(self, x: Tensor): + z0 = op.sort(x, axis=-1, descending=True) + z1 = op.argsort(x, axis=-1, descending=False) + z2 = op.topk(x, k=2, axis=-1) + return z0, z1, z2 + + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor(("seq_len", 64), dtype="float16")): + R.func_attr({"num_input": 1}) + with R.dataflow(): + sort = R.sort(x, axis=-1, descending=True) + argsort = R.argsort(x, axis=-1, descending=False, dtype="int32") + topk = R.topk(x, k=2, axis=-1, ret_type="both", largest=True, dtype="int32") + topk_0 = topk[0] + topk_1 = topk[1] + gv = sort, argsort, (topk_0, topk_1) + R.output(gv) + return gv + + m = Model() + mod, _ = m.export_tvm({"foo": {"x": spec.Tensor(("seq_len", 64), "float16")}}) + + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main()