diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index e16b109ab564..3f323b01a048 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -336,15 +336,17 @@ def base_impl(cls, bb, inputs, attr, params):
         """Base implementation for binary operations."""
         if cls.numpy_op is None or cls.relax_op is None:
             raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = cls.numpy_op(  # pylint: disable=not-callable
-                inputs[0].data.numpy(), inputs[1].data.numpy()
-            )
-            return relax.const(output, inputs[0].struct_info.dtype)
-        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+        if all([not isinstance(inp, (relax.expr.Call, relax.Var)) for inp in inputs]):
             x = _to_numpy(inputs[0])
             y = _to_numpy(inputs[1])
-            return relax.PrimValue(cls.numpy_op(x, y))  # pylint: disable=not-callable
+            output = cls.numpy_op(x, y)  # pylint: disable=not-callable
+            if x.dtype == y.dtype:
+                # no numpy precision widening
+                output = output.astype(x.dtype)
+            if all([isinstance(inp, relax.Constant) for inp in inputs]):
+                return relax.const(output, output.dtype)  # pylint: disable=not-callable
+            if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+                return relax.PrimValue(output.item())  # pylint: disable=not-callable
         return cls.relax_op(inputs[0], inputs[1])  # pylint: disable=not-callable
 
 
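The dtype guard in the hunk above addresses NumPy's implicit type promotion: the folding is done with NumPy, but the folded constant must keep the dtype that ONNX assigned to the operands, so when both inputs share a dtype the result is cast back to it. A minimal standalone sketch of that idea (plain NumPy; `fold_binary` is a hypothetical helper for illustration, not part of the frontend):

```python
import numpy as np


def fold_binary(np_op, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Constant-fold a binary op, undoing NumPy's implicit precision widening."""
    out = np_op(x, y)
    if x.dtype == y.dtype:
        # e.g. int32 / int32 promotes to float64 under NumPy true division;
        # cast back so the folded constant keeps the operands' dtype.
        out = out.astype(x.dtype)
    return out


a = np.array([4], dtype="int32")
b = np.array([3], dtype="int32")
print((a / b).dtype)                       # float64 -- NumPy widened the result
print(fold_binary(np.divide, a, b).dtype)  # int32  -- widening undone
```

For scalar inputs the folded NumPy value is converted back to a Python scalar with `output.item()` before being wrapped in `relax.PrimValue`, so the same dtype handling covers both the `relax.Constant` and the `relax.PrimValue` paths.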
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 3d112c2f3b8a..b55489a623f0 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -27,7 +27,7 @@
 import onnx
 import onnxruntime
 import pytest
-from onnx import ModelProto, TensorProto, helper, mapping
+from onnx import ModelProto, TensorProto, helper
 
 import tvm
 import tvm.testing
@@ -62,7 +62,7 @@ def generate_random_inputs(
 def generate_random_value(shape, elem_type) -> np.ndarray:
     # Extract datatype for the input.
     if elem_type:
-        dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
+        dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
     else:
         dtype = "float32"
 
@@ -87,6 +87,7 @@ def check_correctness(
     opset: int = 14,
     rtol: float = 1e-7,
     atol: float = 1e-5,
+    check_dtypes: bool = False,
 ) -> None:
     """Run an onnx model in both onnxruntime and TVM through our
     importer confirm that the results match. Otherwise, an exception will be raised.
@@ -104,6 +105,8 @@ def check_correctness(
     atol: float
         Set the tolerance of correctness checking. Some ops may be show more
        arithmetic variance than others.
+    check_dtypes: bool
+        Check if data types are the same.
     """
     # Configure model format.
     if ir_version is not None:
@@ -152,17 +155,35 @@ def check_correctness(
         # while the ONNX output number is one, which is a list
         tvm_output = [tvm_output]
 
+    def _get_numpy_subdtype(narray):
+        if np.issubdtype(narray.dtype, np.integer):
+            return "integer"
+        elif np.issubdtype(narray.dtype, np.floating):
+            return "floating"
+        elif np.issubdtype(narray.dtype, np.bool_):
+            return "bool"
+        elif np.issubdtype(narray.dtype, np.complexfloating):
+            return "complexfloating"
+        else:
+            return "other"
+
     def _check_output(tvm_out, ort_out):
         if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
             assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
             for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
                 _check_output(tvm_out_i, ort_out_i)
         elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
+            if check_dtypes:
+                assert tvm_out.numpy().dtype == ort_out.dtype
             tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
         elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
             shape_out = tvm.nd.array([int(i) for i in tvm_out])
+            if check_dtypes:
+                assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
             tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
         elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
+            if check_dtypes:
+                assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
             tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
         else:
             raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
@@ -267,7 +288,7 @@ def verify_binary(
     )
 
     model = helper.make_model(graph, producer_name="binary_test")
-    check_correctness(model, opset=opset)
+    check_correctness(model, opset=opset, check_dtypes=True)
 
 
 def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14):
@@ -282,7 +303,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
     )
 
     model = helper.make_model(graph, producer_name="binary_test")
-    check_correctness(model, opset=opset)
+    check_correctness(model, opset=opset, check_dtypes=True)
 
 
 def verify_compare(op_name, shape, attrs={}, domain=None):
@@ -1897,7 +1918,7 @@ def verify_constantofshape(input_dim, value, dtype):
         ["input"],
         ["output"],
         value=helper.make_tensor(
-            "value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,)
+            "value", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), (1,), (value,)
         ),
     )
 
@@ -1917,7 +1938,7 @@ def verify_constantofshape(input_dim, value, dtype):
         ],
         outputs=[
             helper.make_tensor_value_info(
-                "output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], input_dim
+                "output", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), input_dim
             )
         ],
     )
@@ -2299,7 +2320,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
 
     inputs = [
         helper.make_tensor_value_info(
-            "input", mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype], indata_shape
+            "input", helper.np_dtype_to_tensor_dtype(indata.dtype), indata_shape
         )
     ]
 
@@ -2333,7 +2354,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
         outputs=[
             helper.make_tensor_value_info(
                 f"output_{i}",
-                mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype],
+                helper.np_dtype_to_tensor_dtype(indata.dtype),
                 list(outdata_shapes[i]),
             )
             for i in range(len(split_index))
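On the test side, the removed `onnx.mapping` tables are replaced with the `onnx.helper` conversion functions that newer onnx releases provide in place of the deprecated `mapping` module, and `check_correctness` gains an opt-in `check_dtypes` flag. A quick round trip through the helper functions used above, independent of the test suite:

```python
import numpy as np
from onnx import TensorProto, helper

# TensorProto enum -> NumPy dtype (used in generate_random_value).
assert helper.tensor_dtype_to_np_dtype(TensorProto.INT32) == np.dtype("int32")

# NumPy dtype -> TensorProto enum (used in verify_constantofshape and verify_split).
assert helper.np_dtype_to_tensor_dtype(np.dtype("int32")) == TensorProto.INT32
```

Keeping `check_dtypes` off by default preserves the behavior of existing tests; in this change only `verify_binary` and `verify_binary_scalar` opt in, so the new dtype-preserving constant folding is exercised without tightening every other test.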