Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,49 @@ def __call__(self, *args, **kwargs):
pass


class RequantizeRewriter(DFPatternCallback):
"""Convert ethos-u.requantize composite function to an identity operation."""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.RequantizeParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.RequantizeParams(post.op.body)
params.ifm.tensor = post.args[0]

lut = relay.const([], "int8")

return ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeRequantize:
"""This is the pass that wraps RequantizeRewriter."""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(RequantizeRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand Down Expand Up @@ -1255,6 +1298,7 @@ def transform_module(
mod = LegalizeMean()(mod)
mod = LegalizeConcat()(mod)
mod = LegalizeSigmoid()(mod)
mod = LegalizeRequantize()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
59 changes: 59 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,60 @@ def split_pattern():
return split


class RequantizeParams:
"""
This class will parse a call to ethos-u.requantize composite function
and extract the parameter information.
"""

composite_name = "ethos-u.requantize"

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

layout = "NHWC"
in_var = func_body.args[0]
requantize = func_body

self.ifm = TensorParams(
in_var,
layout=layout,
scale=requantize.args[RequantArgs.IFM_SCALE.value],
zero_point=requantize.args[RequantArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
requantize,
layout=layout,
scale=requantize.args[RequantArgs.OFM_SCALE.value],
zero_point=requantize.args[RequantArgs.OFM_ZERO_POINT.value],
)

attrs = requantize.attrs
self.out_dtype = attrs.out_dtype

def is_valid(self) -> bool:
"""
Checks whether qnn.requantize has compatible attributes with HW.
"""
tensor_params = [self.ifm, self.ofm]
if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
return False
if not check_dimensions(self.ifm) or not check_dimensions(self.ofm):
return False
if self.out_dtype and self.out_dtype != "int8":
return False
return True


def requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for qnn.requantize.
"""
return is_op("qnn.requantize")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant()
)


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
Expand Down Expand Up @@ -1230,6 +1284,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
split_pattern(),
lambda pat: SplitParams(pat).is_valid(),
),
(
RequantizeParams.composite_name,
requantize_pattern(),
lambda pat: RequantizeParams(pat).is_valid(),
),
]


Expand Down
35 changes: 34 additions & 1 deletion tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,10 @@ def concat_func(*inputs):
op = tf.concat(list(inputs), axis)
return op

_compare_tvm_with_tflite(concat_func, shapes, accel_type)
# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
_compare_tvm_with_tflite(concat_func, shapes, accel_type, output_tolerance=1)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
Expand Down Expand Up @@ -987,5 +990,35 @@ def split_func(x):
_compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
[
[(1, 8, 8, 3), 1.0, 0, 1.0, 0],
[(1, 20, 30, 3), 1.345, 34, 0.32, -23],
],
)
def test_ethosu_requantize(accel_type, ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
dtype = "int8"

def create_model():
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
requantize = relay.qnn.op.requantize(
ifm,
relay.const(ifm_scale, dtype="float32"),
relay.const(ifm_zp, dtype="int32"),
relay.const(ofm_scale, dtype="float32"),
relay.const(ofm_zp, dtype="int32"),
)
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))

cpu_mod = create_model()
input_data = {"ifm": np.random.randint(-128, high=127, size=ifm_shape, dtype=dtype)}
output_data = generate_ref_data(cpu_mod, input_data)
ethosu_mod = partition_for_ethosu(cpu_mod)

_compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


if __name__ == "__main__":
pytest.main([__file__])
100 changes: 100 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
pytest.importorskip("ethosu.vela")

import math

import numpy as np
import tensorflow as tf
import tflite.Model
Expand Down Expand Up @@ -1502,5 +1503,104 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize(
"ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
[[(1, 8, 8, 3), 1.0, 0, 1.0, 0], [(1, 20, 30, 3), 1.345, 34, 0.32, -23]],
)
def test_ethosu_requantize(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
dtype = "int8"

def create_model():
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
requantize = relay.qnn.op.requantize(
ifm,
relay.const(ifm_scale, dtype="float32"),
relay.const(ifm_zp, dtype="int32"),
relay.const(ofm_scale, dtype="float32"),
relay.const(ofm_zp, dtype="int32"),
)
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))

def verify(ext_func):
op = ext_func.body

# Check IFM
ifm = op.args[0].checked_type
assert list(ifm.shape) == list(ifm_shape)
assert str(ifm.dtype) == dtype

# Check OFM
ofm = op.checked_type
assert list(ofm.shape) == list(ifm_shape)
assert str(ofm.dtype) == dtype

# Check quantization params
assert math.isclose(op.attrs.ifm_scale, ifm_scale, abs_tol=1e-7)
assert op.attrs.ifm_zero_point == ifm_zp
assert math.isclose(op.attrs.ofm_scale, ofm_scale, abs_tol=1e-7)
assert op.attrs.ofm_zero_point == ofm_zp

rewriter = legalize.RequantizeRewriter()
pattern_table = [
(
ethosu.RequantizeParams.composite_name,
ethosu.requantize_pattern(),
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
),
]

mod = create_model()
mod = partition_ethosu_by_table(mod, pattern_table)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
rewriter, mod["tvmgen_default_ethos_u_main_0"]
)
verify(mod["tvmgen_default_ethos_u_main_0"])


def test_multiple_requantize_offload():
"""
Testing requantize offload in the case one requantize operation is part of
an existing pattern (in this case Mean: cast->mean->requantize) and the
other is a stand-alone requantize.
"""

def create_model():
ifm = relay.var("input", shape=(1, 3, 3, 4), dtype="int8")
cast = relay.cast(ifm, dtype="int32")
mean = relay.mean(cast, axis=1, keepdims=True)
requantize = relay.qnn.op.requantize(
mean,
input_scale=relay.const(1.0, dtype="float32"),
input_zero_point=relay.const(0, dtype="int32"),
output_scale=relay.const(1.0, dtype="float32"),
output_zero_point=relay.const(0, dtype="int32"),
)
requantize = relay.qnn.op.requantize(
requantize,
input_scale=relay.const(1.0, dtype="float32"),
input_zero_point=relay.const(0, dtype="int32"),
output_scale=relay.const(1.0, dtype="float32"),
output_zero_point=relay.const(0, dtype="int32"),
)
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))

def verify(ext_func):
# If mean operation and separate requantize were offloaded correctly,
# there should only be a pooling operation followed by an identity
# operation leagalized.
op = ext_func.body
assert op.op.name == "contrib.ethosu.identity"
op = op.args[0]
assert ext_func.body.args[0].op.name == "contrib.ethosu.pooling"
op = op.args[0]
assert isinstance(op, relay.Var)

mod = create_model()
mod = ethosu.partition_for_ethosu(mod)
mod = legalize.LegalizeEthosU()(mod)
verify(mod["tvmgen_default_ethos_u_main_0"])


if __name__ == "__main__":
pytest.main([__file__])