Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 93 additions & 16 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
github.com/apache/tvm/issues if you hit an error with dynamic kernels.
"""
import math
import operator
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -101,6 +103,83 @@ def get_constant(
return var


def get_value(token, value_dict: Dict[str, tvm.tir.SizeVar]) -> Union[int, tvm.tir.SizeVar]:
"""Converts to token to an integer value if it a constant, otherwise it generates a SizeVar

Parameters
----------
token: str
current token to decode.

value_dict: Dict
The Dictionary mapping from the name of ValueInfoProto to SizeVar.

Returns
-------
Union[int, tvm.tir.SizeVar]
The decoded token
"""

try:
return int(token)
except ValueError:
if token not in value_dict or token == "?":
value_dict[token] = tvm.tir.SizeVar(token, "int64")
value = value_dict[token]
return value


def parse_shape_name(
name: str, value_dict: Dict[str, tvm.tir.SizeVar]
) -> Union[tir.PrimExpr, tvm.tir.SizeVar]:
"""Converts expressions in the shape dimension name to prim expressions.

Parameters
----------
name: str
name of shape dimension.

value_dict: Dict
The Dictionary mapping from the name of ValueInfoProto to SizeVar.

Returns
-------
Union[tir.PrimExpr, tvm.tir.SizeVar]
The expression of the shape dimension.
"""

tokens = re.split(r"(\+|\-|\*|\/\/|\/)", name.replace(" ", ""))

operators = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.floordiv, # is floordiv since the operands are always int
"//": operator.floordiv,
}

value_stack = []
operator_stack = []

for token in tokens:
if token in operators:
operator_stack.append(token)
else:
value = get_value(token, value_dict)
if value_stack and operator_stack:
prev_value = value_stack.pop()
op = operator_stack.pop()
result = operators[op](prev_value, value)
value_stack.append(result)
else:
value_stack.append(value)

if value_stack:
return value_stack[0]
else:
raise Exception("Shape dimension could not be inferred")


def get_info(
info_proto: onnx.onnx_ml_pb2.ValueInfoProto, value_dict: Dict[str, tvm.tir.SizeVar]
) -> Tuple[str, List, str, List, Dict]:
Expand All @@ -126,9 +205,7 @@ def get_info(
name = dim.dim_param
value = dim.dim_value
if value is None or value == 0:
if name not in value_dict or name == "?":
value_dict[name] = tvm.tir.SizeVar(name, "int64")
value = value_dict[name]
value = parse_shape_name(name, value_dict)
shape_name.append(name)
else:
shape_name.append(value)
Expand All @@ -145,9 +222,7 @@ def get_info(
def get_numpy(tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> _np.ndarray:
"""Grab data in TensorProto and convert to numpy array."""
try:
from onnx.numpy_helper import ( # pylint: disable=import-outside-toplevel
to_array,
)
from onnx.numpy_helper import to_array # pylint: disable=import-outside-toplevel
except ImportError as exception:
raise ImportError("Unable to import onnx which is required {}".format(exception))
return to_array(tensor_proto)
Expand Down Expand Up @@ -237,6 +312,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
return relax.op.matmul(inputs[0], inputs[1])


def _to_numpy(x):
if isinstance(x, relax.PrimValue):
x = x.value
if isinstance(x, (tir.IntImm, tir.FloatImm)):
x = x.value
return _np.array(x)
else:
return x.data.numpy()


class BinaryBase(OnnxOpConverter):
"""Converts an onnx BinaryBase node into an equivalent Relax expression."""

Expand All @@ -254,16 +339,8 @@ def base_impl(cls, bb, inputs, attr, params):
)
return relax.const(output, inputs[0].struct_info.dtype)
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
x = (
_np.array(inputs[0].value)
if isinstance(inputs[0], relax.PrimValue)
else inputs[0].data.numpy()
)
y = (
_np.array(inputs[1].value)
if isinstance(inputs[1], relax.PrimValue)
else inputs[1].data.numpy()
)
x = _to_numpy(inputs[0])
y = _to_numpy(inputs[1])
return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable

return cls.relax_op(inputs[0], inputs[1]) # pylint: disable=not-callable
Expand Down
203 changes: 203 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from tvm.relax.frontend.onnx import from_onnx
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I

bg = np.random.MT19937(0)
rg = np.random.Generator(bg)
Expand Down Expand Up @@ -2752,5 +2753,207 @@ def test_params_names_start_with_onnx():
check_correctness(model)


def test_shape_dim_string_expression():
def _verify(x_shape, example_shape):

identity_node = helper.make_node("Identity", ["x"], ["y"])

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(
graph, producer_name="test_var_shape_dim_containing_expressions_onnx"
)

inputs = {"x": generate_random_value(example_shape, TensorProto.FLOAT)}
check_correctness(model, inputs)

_verify(["A", "B", "A + B"], [3, 9, 12])
_verify(["A", "B", "A - B"], [9, 3, 6])
_verify(["A", "B", "A * B"], [9, 3, 27])
_verify(["A", "B", "A // B"], [9, 3, 3])


def test_shape_dim_string_expression_graph_add():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A + B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A + B"), dtype="float32")) -> R.Tensor(("A", "B", "A + B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A + B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_subtract():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A - B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A - B"), dtype="float32")) -> R.Tensor(("A", "B", "A - B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A - B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_mul():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A * B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A * B"), dtype="float32")) -> R.Tensor(("A", "B", "A * B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A * B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_div_1():

identity_node = helper.make_node("Identity", ["x"], ["y"])

# this will result in a floordiv despite not using // since the operands are always int
x_shape = ["A", "B", "A / B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", "B", "A // B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A // B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_div_2():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A // B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", "B", "A // B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A // B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


if __name__ == "__main__":
tvm.testing.main()
Loading