Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,17 @@ def base_impl(cls, bb, inputs, attr, params):
"""Base implementation for binary operations."""
if cls.numpy_op is None or cls.relax_op is None:
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
output = cls.numpy_op( # pylint: disable=not-callable
inputs[0].data.numpy(), inputs[1].data.numpy()
)
return relax.const(output, inputs[0].struct_info.dtype)
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
if all([not isinstance(inp, (relax.expr.Call, relax.Var)) for inp in inputs]):
x = _to_numpy(inputs[0])
y = _to_numpy(inputs[1])
return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable
output = cls.numpy_op(x, y) # pylint: disable=not-callable
if x.dtype == y.dtype:
# no numpy precision widening
output = output.astype(x.dtype)
if all([isinstance(inp, relax.Constant) for inp in inputs]):
return relax.const(output, output.dtype) # pylint: disable=not-callable
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
return relax.PrimValue(output.item()) # pylint: disable=not-callable

return cls.relax_op(inputs[0], inputs[1]) # pylint: disable=not-callable

Expand Down
37 changes: 29 additions & 8 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import onnx
import onnxruntime
import pytest
from onnx import ModelProto, TensorProto, helper, mapping
from onnx import ModelProto, TensorProto, helper

import tvm
import tvm.testing
Expand Down Expand Up @@ -62,7 +62,7 @@ def generate_random_inputs(
def generate_random_value(shape, elem_type) -> np.ndarray:
# Extract datatype for the input.
if elem_type:
dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
else:
dtype = "float32"

Expand All @@ -87,6 +87,7 @@ def check_correctness(
opset: int = 14,
rtol: float = 1e-7,
atol: float = 1e-5,
check_dtypes: bool = False,
) -> None:
"""Run an onnx model in both onnxruntime and TVM through our importer
confirm that the results match. Otherwise, an exception will be raised.
Expand All @@ -104,6 +105,8 @@ def check_correctness(
atol: float
Set the tolerance of correctness checking. Some ops may be show more
arithmetic variance than others.
check_dtypes: bool
Check if data types are the same.
"""
# Configure model format.
if ir_version is not None:
Expand Down Expand Up @@ -152,17 +155,35 @@ def check_correctness(
# while the ONNX output number is one, which is a list
tvm_output = [tvm_output]

def _get_numpy_subdtype(narray):
if np.issubdtype(narray.dtype, np.integer):
return "integer"
elif np.issubdtype(narray.dtype, np.floating):
return "floating"
elif np.issubdtype(narray.dtype, np.bool_):
return "bool"
elif np.issubdtype(narray.dtype, np.complexfloating):
return "complexfloating"
else:
return "other"

def _check_output(tvm_out, ort_out):
if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
_check_output(tvm_out_i, ort_out_i)
elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
if check_dtypes:
assert tvm_out.numpy().dtype == ort_out.dtype
tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
shape_out = tvm.nd.array([int(i) for i in tvm_out])
if check_dtypes:
assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
if check_dtypes:
assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
else:
raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
Expand Down Expand Up @@ -267,7 +288,7 @@ def verify_binary(
)

model = helper.make_model(graph, producer_name="binary_test")
check_correctness(model, opset=opset)
check_correctness(model, opset=opset, check_dtypes=True)


def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14):
Expand All @@ -282,7 +303,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
)

model = helper.make_model(graph, producer_name="binary_test")
check_correctness(model, opset=opset)
check_correctness(model, opset=opset, check_dtypes=True)


def verify_compare(op_name, shape, attrs={}, domain=None):
Expand Down Expand Up @@ -1897,7 +1918,7 @@ def verify_constantofshape(input_dim, value, dtype):
["input"],
["output"],
value=helper.make_tensor(
"value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,)
"value", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), (1,), (value,)
),
)

Expand All @@ -1917,7 +1938,7 @@ def verify_constantofshape(input_dim, value, dtype):
],
outputs=[
helper.make_tensor_value_info(
"output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], input_dim
"output", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), input_dim
)
],
)
Expand Down Expand Up @@ -2299,7 +2320,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o

inputs = [
helper.make_tensor_value_info(
"input", mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype], indata_shape
"input", helper.np_dtype_to_tensor_dtype(indata.dtype), indata_shape
)
]

Expand Down Expand Up @@ -2333,7 +2354,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
outputs=[
helper.make_tensor_value_info(
f"output_{i}",
mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype],
helper.np_dtype_to_tensor_dtype(indata.dtype),
list(outdata_shapes[i]),
)
for i in range(len(split_index))
Expand Down
Loading