diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index 76761fd78325..e62334132ecc 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -24,6 +24,7 @@ from __future__ import absolute_import
 from .mxnet import from_mxnet
+from .mxnet_qnn_op_utils import dequantize_mxnet_min_max
 from .keras import from_keras
 from .onnx import from_onnx
 from .tflite import from_tflite
diff --git a/python/tvm/relay/frontend/mxnet_qnn_op_utils.py b/python/tvm/relay/frontend/mxnet_qnn_op_utils.py
new file mode 100644
index 000000000000..e2aaa794b577
--- /dev/null
+++ b/python/tvm/relay/frontend/mxnet_qnn_op_utils.py
@@ -0,0 +1,246 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return
+"""MXNet QNN dialect helper methods for MXNet-specific implementations of more
+   generic QNN supported ops.
+"""
+
+import numpy as np
+from tvm.relay.qnn.op.qnn import dequantize
+
+zero_centered_uint8_quantized_range = np.float32(255)
+zero_centered_int8_quantized_range = np.float32(127)
+
+
+def _dequantize_zero_centered(data,
+                              data_min,
+                              data_max,
+                              quantized_range):
+    r"""Dequantizes the given data tensor by calculating the scale
+    using the MKLDNN formula `max(abs(data_min), abs(data_max)) / quantized_range`,
+    where `quantized_range` is 255 for uint8 and 127 for int8. `data_min`
+    and `data_max` are the min and max to use for the `data` tensor elements.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be dequantized. Can be of type {int8 or uint8}.
+    data_min : float
+        The minimum to use for the data elements.
+    data_max : float
+        The maximum to use for the data elements.
+    quantized_range : float
+        255 for uint8 and 127 for int8. This is the data type range.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    real_range = np.max([np.abs(np.float32(data_min)),
+                         np.abs(np.float32(data_max))])
+    scale = np.divide(real_range, quantized_range)
+    zero_point = 0
+    return dequantize(data, scale, zero_point)
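+
+# A small worked example of the zero-centered scale computation above
+# (illustrative only; the values mirror the int8 tests in this patch):
+# with data_min = -63.5 and data_max = 64, real_range = max(63.5, 64) = 64,
+# so scale = 64 / 127 ~= 0.503937 and zero_point = 0. A quantized int8
+# value of -126 then dequantizes to -126 * 0.503937 ~= -63.496063.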
+
+
+def _dequantize_mkldnn_min_max_int8(data,
+                                    imin_range,
+                                    imax_range):
+    r"""Dequantizes the given `data` tensor of type int8 using the given
+    min and max ranges; the output data type is `float32`.
+    The method of dequantizing is described here - https://tinyurl.com/y5k6fz5w.
+    We use our default dequantize implementation from src/relay/qnn/op/dequantize.cc:67
+    but compute the `scale` and `zero_point` to fit our equation.
+    Unlike in TFLite, where we get the scale and zero_point from the model, MKLDNN
+    stores the min and max from which we calculate the scale and zero_point.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be dequantized. Can be of type int8.
+    imin_range : float
+        The minimum to use for the data elements.
+    imax_range : float
+        The maximum to use for the data elements.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _dequantize_zero_centered(data,
+                                     data_min=imin_range,
+                                     data_max=imax_range,
+                                     quantized_range=zero_centered_int8_quantized_range)
+
+
+def _dequantize_mkldnn_min_max_uint8(data,
+                                     imin_range,
+                                     imax_range):
+    r"""Dequantizes the given `data` tensor of type uint8 using the given
+    min and max ranges; the output data type is `float32`.
+    The method of dequantizing is described here - https://tinyurl.com/y5k6fz5w.
+    We use our default dequantize implementation from src/relay/qnn/op/dequantize.cc:67
+    but compute the `scale` and `zero_point` to fit our equation.
+    Unlike in TFLite, where we get the scale and zero_point from the model, MKLDNN
+    stores the min and max from which we calculate the scale and zero_point.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be dequantized. Can be of type uint8.
+    imin_range : float
+        The minimum to use for the data elements.
+    imax_range : float
+        The maximum to use for the data elements.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _dequantize_zero_centered(data,
+                                     data_min=imin_range,
+                                     data_max=imax_range,
+                                     quantized_range=zero_centered_uint8_quantized_range)
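+
+
+# Sketch of the difference between the two flavors, as described in the
+# docstrings above and below (not authoritative):
+#
+#   MKLDNN (uint8): scale = max(|min|, |max|) / 255, zero_point = 0
+#   MXNet  (uint8): scale = (max - min) / 255,       zero_point = -min / scale
+#
+# For int8, both flavors are zero-centered and share the same computation.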
+ """ + + iinfo = np.iinfo(np.uint8) + min_limit = np.float64(iinfo.min) + max_limit = np.float64(iinfo.max) + imin_range = np.float64(imin_range) + imax_range = np.float64(imax_range) + scale = np.divide((imax_range - imin_range), + (max_limit - min_limit)) + zero_point = np.int(-1 * np.divide(imin_range, scale)) + return dequantize(data, scale, zero_point) + + +def dequantize_mxnet_min_max(data, + min_range, + max_range, + in_dtype='int8', + use_mkldnn=False): + r"""Dequantizes the given `data` in {int8 or uint8} and the given + min and max ranges. The output data type is float32. + Only `float32` is supported as output data types. + The input data type is expected to be {int8 or uint8}. + Mxnet has two different flavors for dequantization 1) Default 2)MKLDNN. + To get the second one Mxnet must be built with MKLDNN during compile time. + Users can choose either of the implementation for TVM runtime. + The main difference between the two implementation is that MKLDNN is centered + around 0 and the default implementation for uint8 is not. + + Parameters + ---------- + data : tvm.relay.Expr + The input tensor to be quantized. Can be of type float32. + min_range : float + The minimum to use data elements for the output. + max_range : float + The maximum to use for data elements for the output. + in_dtype: str, optional + The input data type, can be 'int8' or 'uint8' + use_mkldnn: bool, optional + If True then uses MKLDNN quantization implementation otherwise + will use default implementation. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + if in_dtype == 'uint8': + if use_mkldnn: + return _dequantize_mkldnn_min_max_uint8(data, + min_range, + max_range) + else: + return _dequantize_mxnet_min_max_uint8(data, + min_range, + max_range) + elif in_dtype == 'int8': + if use_mkldnn: + return _dequantize_mkldnn_min_max_int8(data, min_range, max_range) + else: + return _dequantize_mxnet_min_max_int8(data, min_range, max_range) + else: + raise ValueError( + "Expected out_dtype to be int8 or uint8 but was %s" % in_dtype) diff --git a/tests/python/frontend/mxnet/test_qnn_ops_utils.py b/tests/python/frontend/mxnet/test_qnn_ops_utils.py new file mode 100644 index 000000000000..78c9692ea5b3 --- /dev/null +++ b/tests/python/frontend/mxnet/test_qnn_ops_utils.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/python/frontend/mxnet/test_qnn_ops_utils.py b/tests/python/frontend/mxnet/test_qnn_ops_utils.py
new file mode 100644
index 000000000000..78c9692ea5b3
--- /dev/null
+++ b/tests/python/frontend/mxnet/test_qnn_ops_utils.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.contrib import graph_runtime
+
+
+def test_mxnet_dequantize_op():
+
+    def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+        shape = in_data.shape
+        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+        min_range = quant_args['min_range']
+        max_range = quant_args['max_range']
+        dequantized_output = \
+            relay.frontend.dequantize_mxnet_min_max(input_data,
+                                                    min_range=min_range,
+                                                    max_range=max_range,
+                                                    in_dtype=in_dtype)
+        mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output)
+        mod = relay.Module.from_expr(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "llvm", params=None)
+            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            rt_mod.set_input(input_data=in_data)
+            rt_mod.set_input(**params)
+            rt_mod.run()
+            res = rt_mod.get_output(0).asnumpy()
+            assert np.allclose(res, verify_output_data)
+            assert res.dtype == np.float32
+
+    def test_uint8_to_float32():
+        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+            .astype('uint8') \
+            .reshape((2, 5))
+        output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        dequantize_test_driver(in_dtype='uint8',
+                               quant_args=quant_args,
+                               in_data=data,
+                               verify_output_data=output)
+
+    def test_int8_to_float32():
+        data = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \
+            .astype('int8') \
+            .reshape((2, 5))
+        output = np.array([-63.496063, -62.992126, -62.48819, -61.984253, -61.480316,
+                           61.984253, 62.48819, 62.992126, 63.496063, 64.]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        dequantize_test_driver(in_dtype='int8',
+                               quant_args=quant_args,
+                               in_data=data,
+                               verify_output_data=output)
+
+    test_uint8_to_float32()
+    test_int8_to_float32()
+
+
+def test_mkldnn_dequantize_op():
+
+    def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+        shape = in_data.shape
+        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+        min_range = quant_args['min_range']
+        max_range = quant_args['max_range']
+        dequantized_output = \
+            relay.frontend.dequantize_mxnet_min_max(input_data,
+                                                    min_range=min_range,
+                                                    max_range=max_range,
+                                                    in_dtype=in_dtype,
+                                                    use_mkldnn=True)
+        mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output)
+        mod = relay.Module.from_expr(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "llvm", params=None)
+            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            rt_mod.set_input(input_data=in_data)
+            rt_mod.set_input(**params)
+            rt_mod.run()
+            res = rt_mod.get_output(0).asnumpy()
+            assert np.allclose(res, verify_output_data)
+            assert res.dtype == np.float32
+
+    def test_uint8_to_float32():
+        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+            .astype('uint8') \
+            .reshape((2, 5))
+        output = np.array([0., 0.2509804, 0.5019608, 0.75294125, 1.0039216,
+                           62.996082, 63.247063, 63.498043, 63.749023, 64.]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        dequantize_test_driver(in_dtype='uint8',
+                               quant_args=quant_args,
+                               in_data=data,
+                               verify_output_data=output)
+
+    def test_int8_to_float32():
+        data = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \
+            .astype('int8') \
+            .reshape((2, 5))
+        output = np.array([-63.496063, -62.992126, -62.48819, -61.984253, -61.480316,
+                           61.984253, 62.48819, 62.992126, 63.496063, 64.]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        dequantize_test_driver(in_dtype='int8',
+                               quant_args=quant_args,
+                               in_data=data,
+                               verify_output_data=output)
+
+    test_uint8_to_float32()
+    test_int8_to_float32()
+
+
+if __name__ == "__main__":
+    test_mxnet_dequantize_op()
+    test_mkldnn_dequantize_op()
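
For reference, the expected outputs in the uint8 tests above follow directly from the two flavors' formulas; a plain NumPy sketch (illustrative only, not part of the patch) that reproduces them:

    import numpy as np

    data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255], dtype=np.uint8)
    min_range, max_range = -63.5, 64.0

    # Default MXNet flavor: scale spans [min, max]; zero point is shifted.
    scale = (max_range - min_range) / 255.0                     # 0.5
    zero_point = int(-min_range / scale)                        # 127
    print((data.astype(np.float32) - zero_point) * scale)       # [-63.5 ... 64.]

    # MKLDNN flavor: zero-centered scale with zero_point fixed at 0.
    mkldnn_scale = max(abs(min_range), abs(max_range)) / 255.0  # ~0.2509804
    print(data.astype(np.float32) * mkldnn_scale)               # [0. ... 64.]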