78 changes: 78 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -298,6 +298,83 @@ def calculate_lut_value(i):
return identity


class HardSwishRewriter(DFPatternCallback):
"""Convert ethosu.hard_swish composite function to add operation with LUT."""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.params_class = ethosu_patterns.HardSwishParams
self.pattern = wildcard().has_attr({"Composite": self.params_class.composite_name})(
wildcard()
)

def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
params = self.params_class(post.op.body)
params.ifm.tensor = post.args[0]

# The calculation of the LUT values is similar to that in Vela
# convert_hardswish_to_lut(op, arch, nng)
# (https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.2.0/ethosu/vela/tflite_graph_optimiser.py#719) # pylint: disable=line-too-long
input_scale = np.double(params.ifm.q_params.scale_f32)
input_zp = int(params.ifm.q_params.zero_point)
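        # Promote the input to a "high resolution" 16-bit representation:
        # calculate_lut_values below multiplies the zero-point-corrected
        # value by 128, so the effective input scale shrinks by the same
        # factor.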
hires_input_scale = (1 / 128) * input_scale

output_scale = np.double(params.ofm.q_params.scale_f32)
output_zp = int(params.ofm.q_params.zero_point)
output_scale, output_shift = scaling.quantise_scale(hires_input_scale / output_scale)
output_scale_16 = fp_math.downscale_multiplier_int32_to_int16(output_scale)
output_shift = 31 - output_shift
output_shift = -output_shift if output_shift < 0 else 0

dtype = params.ifm.dtype
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max

def calculate_relu_multiplier(inp, input_scale):
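            # Evaluates relu6(x + 3) / 6 in 16-bit fixed point. The input is
            # rescaled so that real values of +/-3 map to +/-1 in Q0.15, with
            # the saturating arithmetic clamping anything outside that range;
            # the final line remaps [-1, 1] onto [0, 1].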
rmultiplier = np.double(3 / 32768)
rscale, rshift = scaling.quantise_scale(input_scale / rmultiplier)
rscale_16 = fp_math.downscale_multiplier_int32_to_int16(rscale)

rvalue = np.int16(inp)
if rshift < 31:
rvalue = fp_math.shift_left16(rvalue, 30 - rshift)
rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
rvalue = fp_math.shift_left16(rvalue, 1)
elif rshift > 31:
rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
rvalue = fp_math.rounding_divide_by_pot(rvalue, rshift - 31)
else:
rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)

rvalue = (rvalue + (1 << 15)) >> 1
return rvalue

def calculate_lut_values(i):
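            # For each possible int8 input, compute
            # hard_swish(x) = x * relu6(x + 3) / 6 in 16-bit fixed point,
            # with the output requantization folded into output_scale_16 and
            # output_shift.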
hires_input_value = (i - input_zp) * 128
preshift_input_value = fp_math.saturating_rounding_mul16(
hires_input_value, output_scale_16
)
relu_value = calculate_relu_multiplier(hires_input_value, hires_input_scale)
lut_result = fp_math.saturating_mul16(relu_value, preshift_input_value)
lut_result = fp_math.rounding_divide_by_pot(lut_result, output_shift) + output_zp
return min(qmax, max(qmin, lut_result))

values = list(map(calculate_lut_values, range(-128, 128)))
lut = relay.const(values, dtype=dtype)

# We baked the requantization into the LUT, so we don't requantize the identity operator
identity = ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut,
ifm_scale=input_scale,
ifm_zero_point=input_zp,
ofm_scale=input_scale,
ofm_zero_point=input_zp,
activation="LUT",
)

return identity


class Conv2DRewriter(DFPatternCallback):
"""Convert conv2d related composite functions into ethosu_conv2d operators"""

@@ -1306,6 +1383,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
ShlRewriter(),
AbsRewriter(),
TanhRewriter(),
HardSwishRewriter(),
LeakyReLURewriter(),
MeanRewriter(),
ConcatRewriter(),
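
For intuition, the 256-entry LUT built by HardSwishRewriter is hard_swish evaluated at every possible int8 input, with the output requantization folded in. Below is a minimal float reference for sanity-checking such a table; it is not the Vela fixed-point path used by the rewriter, and the scale/zero-point defaults are illustrative assumptions:

import numpy as np

def hard_swish_lut_reference(input_scale=0.02, input_zp=0, output_scale=0.02, output_zp=0):
    # Dequantize every possible int8 value, apply float
    # hard_swish(x) = x * relu6(x + 3) / 6, then requantize back to int8.
    i = np.arange(-128, 128)
    x = (i - input_zp) * input_scale
    y = x * np.clip(x + 3.0, 0.0, 6.0) / 6.0
    q = np.round(y / output_scale) + output_zp
    return np.clip(q, -128, 127).astype(np.int8)

The values produced by calculate_lut_values should track this reference to within a small rounding error.
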
53 changes: 53 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1724,6 +1724,54 @@ def qnn_fc_pattern():
return optional_clip


class HardSwishParams:
"""
    This class will parse a call to an ethos-u.hard_swish composite function
and extract the parameter information.
"""

composite_name = "ethos-u.hard_swish"

def __init__(self, func_body):
from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs

quantize = func_body
divide = quantize.args[0]
multiply = divide.args[0]
clip = multiply.args[1]
add = clip.args[0]
dequantize = add.args[0]

self.ifm = TensorParams(
dequantize.args[0],
scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
quantize,
scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
)

def is_valid(self):
tensor_params = [self.ifm, self.ofm]
if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
return False
return True


def hard_swish_pattern():
"""Create the pattern for hard swish."""
dequantize = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
add = is_op("add")(dequantize, is_constant())
clip = is_op("clip")(add)
multiply = is_op("multiply")(dequantize, clip)
divide = is_op("divide")(multiply, is_constant())
quantize = is_op("qnn.quantize")(divide, is_constant(), is_constant())
return quantize
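
As an illustration of the subgraph shape this pattern expects (dequantize -> add -> clip -> multiply -> divide -> quantize, with the dequantize feeding both the add and the multiply), here is a hand-built Relay expression; the shape, scales and zero points are arbitrary assumptions:

from tvm import relay

x = relay.var("x", shape=(1, 5, 5, 3), dtype="int8")
deq = relay.qnn.op.dequantize(x, relay.const(0.02, "float32"), relay.const(0, "int32"))
add = relay.add(deq, relay.const(3.0, "float32"))
clip = relay.clip(add, a_min=0.0, a_max=6.0)
mul = relay.multiply(deq, clip)
div = relay.divide(mul, relay.const(6.0, "float32"))
out = relay.qnn.op.quantize(div, relay.const(0.02, "float32"), relay.const(0, "int32"), out_dtype="int8")

assert hard_swish_pattern().match(out)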


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
@@ -1844,6 +1892,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
squeeze_pattern(),
lambda pat: SqueezeParams(pat).is_valid(),
),
(
HardSwishParams.composite_name,
hard_swish_pattern(),
lambda pat: HardSwishParams(pat).is_valid(),
),
]


15 changes: 15 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -819,6 +819,21 @@ def tanh_func(x):
)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
def test_tflite_hard_swish(accel_type, ifm_shape):
np.random.seed(0)

@tf.function
def hard_swish_func(x):
op = tf.keras.layers.Lambda(
lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
)(x)
return op

infra.compare_tvm_with_tflite(hard_swish_func, [ifm_shape], accel_type, ranges=[(-1, 1)])


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"shapes, axis",
55 changes: 55 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -2751,5 +2751,60 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
def test_tflite_hard_swish(ifm_shape):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x):
op = tf.keras.layers.Lambda(
lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
)(x)
return op

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, tf.float32)
)

def representative_dataset():
for _ in range(100):
                data = np.random.rand(*ifm_shape)
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

return tflite_model

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)

mod = ethosu.partition_for_ethosu(mod, params)
mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.HardSwishRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)
mod = relay.transform.InferType()(mod)

func_body = mod["tvmgen_default_ethos_u_main_0"].body
assert func_body.op.name == "contrib.ethosu.identity"
assert func_body.attrs.activation == "LUT"
    assert tuple(func_body.args[0].checked_type.shape) == ifm_shape
assert tuple(func_body.args[1].checked_type.shape) == (256,)


if __name__ == "__main__":
pytest.main([__file__])