diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index cbd633324a75..eb7a3eaf3628 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -35,6 +35,8 @@
 github.com/apache/tvm/issues if you hit an error with dynamic kernels.
 """
 import math
+import operator
+import re
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -101,6 +103,98 @@ def get_constant(
     return var
 
 
+def get_value(token, value_dict: Dict[str, tvm.tir.SizeVar]) -> Union[int, tvm.tir.SizeVar]:
+    """Converts the token to an integer if it is a constant, otherwise to a SizeVar.
+
+    A SizeVar is created once per token and cached in value_dict, except for the
+    token `?`, which gets a new SizeVar on every call.
+
+    Parameters
+    ----------
+    token: str
+        current token to decode.
+
+    value_dict: Dict
+        The Dictionary mapping from the name of ValueInfoProto to SizeVar.
+
+    Returns
+    -------
+    Union[int, tvm.tir.SizeVar]
+        The decoded token
+    """
+
+    try:
+        return int(token)
+    except ValueError:
+        if token not in value_dict or token == "?":
+            value_dict[token] = tvm.tir.SizeVar(token, "int64")
+        return value_dict[token]
+
+
+def parse_shape_name(
+    name: str, value_dict: Dict[str, tvm.tir.SizeVar]
+) -> Union[tir.PrimExpr, tvm.tir.SizeVar]:
+    """Converts expressions in the shape dimension name to prim expressions.
+
+    The name is read as a flat arithmetic expression over integer constants and
+    dimension names. `*`, `/` and `//` bind tighter than `+` and `-`, operators of
+    equal precedence are applied left to right, and parentheses are not supported.
+    A name with a missing operand (e.g. `-1` or `A +`) is not treated as an
+    expression and is kept as a single symbolic dimension.
+
+    Parameters
+    ----------
+    name: str
+        name of shape dimension.
+
+    value_dict: Dict
+        The Dictionary mapping from the name of ValueInfoProto to SizeVar.
+
+    Returns
+    -------
+    Union[tir.PrimExpr, tvm.tir.SizeVar]
+        The expression of the shape dimension.
+    """
+
+    # The capture group makes re.split keep the operators, so tokens alternates
+    # operand, operator, operand, ... and starts and ends with an operand.
+    tokens = re.split(r"(\+|\-|\*|\/\/|\/)", name.replace(" ", ""))
+    if len(tokens) == 1:
+        return get_value(tokens[0], value_dict)
+
+    if "" in tokens:
+        # An operand is missing, so this is not a well-formed expression. Keep the
+        # whole name as a single symbolic dimension instead of guessing.
+        if name not in value_dict:
+            value_dict[name] = tvm.tir.SizeVar(name, "int64")
+        return value_dict[name]
+
+    additive = {"+": operator.add, "-": operator.sub}
+    multiplicative = {
+        "*": operator.mul,
+        "/": operator.floordiv,  # is floordiv since the operands are always int
+        "//": operator.floordiv,
+    }
+
+    # Multiplicative operators are folded into the last term as soon as they are
+    # seen and additive ones only once all terms are known, so that `*`, `/` and
+    # `//` bind tighter than `+` and `-`.
+    terms = [get_value(tokens[0], value_dict)]
+    term_ops = []
+    for op, operand in zip(tokens[1::2], tokens[2::2]):
+        value = get_value(operand, value_dict)
+        if op in multiplicative:
+            terms[-1] = multiplicative[op](terms[-1], value)
+        else:
+            term_ops.append(op)
+            terms.append(value)
+
+    result = terms[0]
+    for op, term in zip(term_ops, terms[1:]):
+        result = additive[op](result, term)
+    return result
+
+
 def get_info(
     info_proto: onnx.onnx_ml_pb2.ValueInfoProto, value_dict: Dict[str, tvm.tir.SizeVar]
 ) -> Tuple[str, List, str, List, Dict]:
@@ -126,9 +220,7 @@ def get_info(
         name = dim.dim_param
         value = dim.dim_value
         if value is None or value == 0:
-            if name not in value_dict or name == "?":
-                value_dict[name] = tvm.tir.SizeVar(name, "int64")
-            value = value_dict[name]
+            value = parse_shape_name(name, value_dict)
             shape_name.append(name)
         else:
             shape_name.append(value)
@@ -145,9 +237,7 @@ def get_info(
 def get_numpy(tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> _np.ndarray:
     """Grab data in TensorProto and convert to numpy array."""
     try:
-        from onnx.numpy_helper import (  # pylint: disable=import-outside-toplevel
-            to_array,
-        )
+        from onnx.numpy_helper import to_array  # pylint: disable=import-outside-toplevel
     except ImportError as exception:
         raise ImportError("Unable to import onnx which is required {}".format(exception))
     return to_array(tensor_proto)
@@ -237,6 +327,17 @@ def _impl_v13(cls, bb, inputs, attr, params):
         return relax.op.matmul(inputs[0], inputs[1])
 
 
+def _to_numpy(x):
+    """Converts a relax.PrimValue or an object exposing `.data.numpy()` to numpy."""
+    if isinstance(x, relax.PrimValue):
+        x = x.value
+        if isinstance(x, (tir.IntImm, tir.FloatImm)):
+            x = x.value
+        return _np.array(x)
+    else:
+        return x.data.numpy()
+
+
 class BinaryBase(OnnxOpConverter):
     """Converts an onnx BinaryBase node into an equivalent Relax expression."""
 
@@ -254,16 +355,8 @@ def base_impl(cls, bb, inputs, attr, params):
             )
             return relax.const(output, inputs[0].struct_info.dtype)
         if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
-            x = (
-                _np.array(inputs[0].value)
-                if isinstance(inputs[0], relax.PrimValue)
-                else inputs[0].data.numpy()
-            )
-            y = (
-                _np.array(inputs[1].value)
-                if isinstance(inputs[1], relax.PrimValue)
-                else inputs[1].data.numpy()
-            )
+            x = _to_numpy(inputs[0])
+            y = _to_numpy(inputs[1])
             return relax.PrimValue(cls.numpy_op(x, y))  # pylint: disable=not-callable
         return cls.relax_op(inputs[0], inputs[1])  # pylint: disable=not-callable
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index c130bf43730b..6f74957a0781 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -35,6 +35,7 @@
 from tvm.relax.frontend.onnx import from_onnx
 from tvm.script import relax as R
 from tvm.script import tir as T
+from tvm.script import ir as I
 
 bg = np.random.MT19937(0)
 rg = np.random.Generator(bg)
@@ -2752,5 +2753,209 @@ def test_params_names_start_with_onnx():
     check_correctness(model)
 
 
+def test_shape_dim_string_expression():
+    def _verify(x_shape, example_shape):
+
+        identity_node = helper.make_node("Identity", ["x"], ["y"])
+
+        graph = helper.make_graph(
+            [identity_node],
+            "test_var_shape_dim_containing_expressions_onnx",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
+            ],
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
+        )
+        model = helper.make_model(
+            graph, producer_name="test_var_shape_dim_containing_expressions_onnx"
+        )
+
+        inputs = {"x": generate_random_value(example_shape, TensorProto.FLOAT)}
+        check_correctness(model, inputs)
+
+    _verify(["A", "B", "A + B"], [3, 9, 12])
+    _verify(["A", "B", "A - B"], [9, 3, 6])
+    _verify(["A", "B", "A * B"], [9, 3, 27])
+    _verify(["A", "B", "A // B"], [9, 3, 3])
+    # "*" binds tighter than "+": 3 + 4 * 2 == 11
+    _verify(["A", "B", "A + B * 2"], [3, 4, 11])
+
+
+def test_shape_dim_string_expression_graph_add():
+
+    identity_node = helper.make_node("Identity", ["x"], ["y"])
+
+    x_shape = ["A", "B", "A + B"]
+
+    graph = helper.make_graph(
+        [identity_node],
+        "test_var_shape_dim_containing_expressions_onnx",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")
+
+    tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("A", "B", "A + B"), dtype="float32")) -> R.Tensor(("A", "B", "A + B"), dtype="float32"):
+            A = T.int64(is_size_var=True)
+            B = T.int64(is_size_var=True)
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((A, B, A + B), dtype="float32") = x
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    tvm.ir.assert_structural_equal(tvm_model, Expected)
+
+
+def test_shape_dim_string_expression_graph_subtract():
+
+    identity_node = helper.make_node("Identity", ["x"], ["y"])
+
+    x_shape = ["A", "B", "A - B"]
+
+    graph = helper.make_graph(
+        [identity_node],
+        "test_var_shape_dim_containing_expressions_onnx",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")
+
+    tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("A", "B", "A - B"), dtype="float32")) -> R.Tensor(("A", "B", "A - B"), dtype="float32"):
+            A = T.int64(is_size_var=True)
+            B = T.int64(is_size_var=True)
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((A, B, A - B), dtype="float32") = x
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    tvm.ir.assert_structural_equal(tvm_model, Expected)
+
+
+def test_shape_dim_string_expression_graph_mul():
+
+    identity_node = helper.make_node("Identity", ["x"], ["y"])
+
+    x_shape = ["A", "B", "A * B"]
+
+    graph = helper.make_graph(
+        [identity_node],
+        "test_var_shape_dim_containing_expressions_onnx",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")
+
+    tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("A", "B", "A * B"), dtype="float32")) -> R.Tensor(("A", "B", "A * B"), dtype="float32"):
+            A = T.int64(is_size_var=True)
+            B = T.int64(is_size_var=True)
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((A, B, A * B), dtype="float32") = x
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    tvm.ir.assert_structural_equal(tvm_model, Expected)
+
+
+def test_shape_dim_string_expression_graph_div_1():
+
+    identity_node = helper.make_node("Identity", ["x"], ["y"])
+
+    # this will result in a floordiv despite not using // since the operands are always int
+    x_shape = ["A", "B", "A / B"]
+
+    graph = helper.make_graph(
+        [identity_node],
+        "test_var_shape_dim_containing_expressions_onnx",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")
+
+    tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", "B", "A // B"), dtype="float32"):
+            A = T.int64(is_size_var=True)
+            B = T.int64(is_size_var=True)
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((A, B, A // B), dtype="float32") = x
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    tvm.ir.assert_structural_equal(tvm_model, Expected)
+
+
+def test_shape_dim_string_expression_graph_div_2():
+
+    identity_node = helper.make_node("Identity", ["x"], ["y"])
+
+    x_shape = ["A", "B", "A // B"]
+
+    graph = helper.make_graph(
+        [identity_node],
+        "test_var_shape_dim_containing_expressions_onnx",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")
+
+    tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", "B", "A // B"), dtype="float32"):
+            A = T.int64(is_size_var=True)
+            B = T.int64(is_size_var=True)
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((A, B, A // B), dtype="float32") = x
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    tvm.ir.assert_structural_equal(tvm_model, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()