diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index d52f3ba6eca5..b16374028b1a 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1226,6 +1226,49 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class RequantizeRewriter(DFPatternCallback):
+    """Convert an ethos-u.requantize composite function to an identity operation."""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.RequantizeParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.RequantizeParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        lut = relay.const([], "int8")
+
+        return ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeRequantize:
+    """This is the pass that wraps RequantizeRewriter."""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(RequantizeRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -1255,6 +1298,7 @@ def transform_module(
         mod = LegalizeMean()(mod)
         mod = LegalizeConcat()(mod)
         mod = LegalizeSigmoid()(mod)
+        mod = LegalizeRequantize()(mod)
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
         mod = LegalizeNoOps()(mod)
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 73007cffe726..9ea1e2bb1fc3 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1145,6 +1145,63 @@ def split_pattern():
     return split
 
 
+class RequantizeParams:
+    """
+    This class will parse a call to an ethos-u.requantize composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.requantize"
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+
+        layout = "NHWC"
+        in_var = func_body.args[0]
+        requantize = func_body
+
+        self.ifm = TensorParams(
+            in_var,
+            layout=layout,
+            scale=requantize.args[RequantArgs.IFM_SCALE.value],
+            zero_point=requantize.args[RequantArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            requantize,
+            layout=layout,
+            scale=requantize.args[RequantArgs.OFM_SCALE.value],
+            zero_point=requantize.args[RequantArgs.OFM_ZERO_POINT.value],
+        )
+
+        attrs = requantize.attrs
+        self.out_dtype = attrs.out_dtype
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether qnn.requantize has attributes compatible with the hardware.
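+
+        Requantize is offloaded only when the input and output tensors are
+        int8 and the output dtype, if specified, is also int8.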
+ """ + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]): + return False + if not check_dimensions(self.ifm) or not check_dimensions(self.ofm): + return False + if self.out_dtype and self.out_dtype != "int8": + return False + return True + + +def requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for qnn.requantize. + """ + return is_op("qnn.requantize")( + wildcard(), is_constant(), is_constant(), is_constant(), is_constant() + ) + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1230,6 +1284,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal split_pattern(), lambda pat: SplitParams(pat).is_valid(), ), + ( + RequantizeParams.composite_name, + requantize_pattern(), + lambda pat: RequantizeParams(pat).is_valid(), + ), ] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 4042bb057bd3..1af8a60158fb 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -950,7 +950,10 @@ def concat_func(*inputs): op = tf.concat(list(inputs), axis) return op - _compare_tvm_with_tflite(concat_func, shapes, accel_type) + # TODO(lhutton1) For now output is not bit exact with TFLite. + # This is because TFLite reference kernels are not being used. + # For this, TFLite will need upgrading to 2.6. + _compare_tvm_with_tflite(concat_func, shapes, accel_type, output_tolerance=1) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -987,5 +990,35 @@ def split_func(x): _compare_tvm_with_tflite(split_func, [ifm_shape], accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp", + [ + [(1, 8, 8, 3), 1.0, 0, 1.0, 0], + [(1, 20, 30, 3), 1.345, 34, 0.32, -23], + ], +) +def test_ethosu_requantize(accel_type, ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp): + dtype = "int8" + + def create_model(): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + requantize = relay.qnn.op.requantize( + ifm, + relay.const(ifm_scale, dtype="float32"), + relay.const(ifm_zp, dtype="int32"), + relay.const(ofm_scale, dtype="float32"), + relay.const(ofm_zp, dtype="int32"), + ) + return tvm.IRModule.from_expr(relay.Function([ifm], requantize)) + + cpu_mod = create_model() + input_data = {"ifm": np.random.randint(-128, high=127, size=ifm_shape, dtype=dtype)} + output_data = generate_ref_data(cpu_mod, input_data) + ethosu_mod = partition_for_ethosu(cpu_mod) + + _compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 9f979153f714..f05fec9d124b 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -21,6 +21,7 @@ pytest.importorskip("ethosu.vela") import math + import numpy as np import tensorflow as tf import tflite.Model @@ -1502,5 +1503,104 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +@pytest.mark.parametrize( + "ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp", + [[(1, 8, 8, 3), 1.0, 0, 1.0, 0], [(1, 20, 30, 3), 1.345, 34, 0.32, -23]], +) +def test_ethosu_requantize(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp): 
+    dtype = "int8"
+
+    def create_model():
+        ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+        requantize = relay.qnn.op.requantize(
+            ifm,
+            relay.const(ifm_scale, dtype="float32"),
+            relay.const(ifm_zp, dtype="int32"),
+            relay.const(ofm_scale, dtype="float32"),
+            relay.const(ofm_zp, dtype="int32"),
+        )
+        return tvm.IRModule.from_expr(relay.Function([ifm], requantize))
+
+    def verify(ext_func):
+        op = ext_func.body
+
+        # Check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+
+        # Check OFM
+        ofm = op.checked_type
+        assert list(ofm.shape) == list(ifm_shape)
+        assert str(ofm.dtype) == dtype
+
+        # Check quantization params
+        assert math.isclose(op.attrs.ifm_scale, ifm_scale, abs_tol=1e-7)
+        assert op.attrs.ifm_zero_point == ifm_zp
+        assert math.isclose(op.attrs.ofm_scale, ofm_scale, abs_tol=1e-7)
+        assert op.attrs.ofm_zero_point == ofm_zp
+
+    rewriter = legalize.RequantizeRewriter()
+    pattern_table = [
+        (
+            ethosu.RequantizeParams.composite_name,
+            ethosu.requantize_pattern(),
+            lambda pat: ethosu.RequantizeParams(pat).is_valid(),
+        ),
+    ]
+
+    mod = create_model()
+    mod = partition_ethosu_by_table(mod, pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        rewriter, mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
+def test_multiple_requantize_offload():
+    """
+    Tests requantize offload in the case where one requantize operation is
+    part of an existing pattern (here Mean: cast->mean->requantize) and the
+    other is a stand-alone requantize.
+    """
+
+    def create_model():
+        ifm = relay.var("input", shape=(1, 3, 3, 4), dtype="int8")
+        cast = relay.cast(ifm, dtype="int32")
+        mean = relay.mean(cast, axis=1, keepdims=True)
+        requantize = relay.qnn.op.requantize(
+            mean,
+            input_scale=relay.const(1.0, dtype="float32"),
+            input_zero_point=relay.const(0, dtype="int32"),
+            output_scale=relay.const(1.0, dtype="float32"),
+            output_zero_point=relay.const(0, dtype="int32"),
+        )
+        requantize = relay.qnn.op.requantize(
+            requantize,
+            input_scale=relay.const(1.0, dtype="float32"),
+            input_zero_point=relay.const(0, dtype="int32"),
+            output_scale=relay.const(1.0, dtype="float32"),
+            output_zero_point=relay.const(0, dtype="int32"),
+        )
+        return tvm.IRModule.from_expr(relay.Function([ifm], requantize))
+
+    def verify(ext_func):
+        # If the mean operation and the stand-alone requantize were offloaded
+        # correctly, the legalized function should contain only a pooling
+        # operation followed by an identity operation.
+        op = ext_func.body
+        assert op.op.name == "contrib.ethosu.identity"
+        op = op.args[0]
+        assert op.op.name == "contrib.ethosu.pooling"
+        op = op.args[0]
+        assert isinstance(op, relay.Var)
+
+    mod = create_model()
+    mod = ethosu.partition_for_ethosu(mod)
+    mod = legalize.LegalizeEthosU()(mod)
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
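
Below is a usage sketch of the flow this patch enables. It assumes only the
public APIs already exercised by the tests above (partition_for_ethosu,
LegalizeEthosU, and the default "tvmgen_default_ethos_u_main_0" function
naming); it is illustrative rather than part of the change:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.ethosu import partition_for_ethosu
    from tvm.relay.backend.contrib.ethosu import legalize

    # Build a module containing a single stand-alone requantize.
    ifm = relay.var("ifm", shape=(1, 8, 8, 3), dtype="int8")
    requantize = relay.qnn.op.requantize(
        ifm,
        relay.const(1.0, dtype="float32"),
        relay.const(0, dtype="int32"),
        relay.const(0.5, dtype="float32"),
        relay.const(-3, dtype="int32"),
    )
    mod = tvm.IRModule.from_expr(relay.Function([ifm], requantize))

    # Partition the requantize for the NPU, then legalize it. The partitioned
    # function is expected to contain a single contrib.ethosu.identity call
    # carrying the input/output scales and zero points as attributes.
    mod = partition_for_ethosu(mod)
    mod = legalize.LegalizeEthosU()(mod)
    print(mod["tvmgen_default_ethos_u_main_0"])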