2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -89,7 +89,7 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
not refer to an Op. Else, a new call node with a new operator.
"""
new_call = call
lut_activations = ["TANH", "LUT"]
lut_activations = ["TANH", "LUT", "SIGMOID"]

if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call):
producer_op = call.args[0]
84 changes: 71 additions & 13 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
"""A set of passes to legalize some of operations for the NPU"""
from typing import List, Type
from typing import List, Type, Callable
import math

import numpy as np # type: ignore
@@ -125,32 +125,36 @@ def __call__(self, *args, **kwargs):
pass


def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
"""Method to calculate the values of the tanh lookup table"""
def get_lut_from_func(
ifm_scale: float, ifm_zp: int, ofm_scale: float, ofm_zp: int, func: Callable[[float], float]
) -> List[int]:
"""Method to calculate the values of the lookup table based on the calculation function"""
lut_values = list()
# Only int8 is currently supported
dtype = np.int8
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
for x in range(qmin, qmax + 1):
x_real = ifm_scale * (x - ifm_zp)
out_real = math.tanh(x_real)
out_real = func(x_real)
lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
lut_values.append(lut_result)

return lut_values


class TanhRewriter(DFPatternCallback):
"""This pass adds tanh as a LUT to the identity operator"""
class LutActivationRewriter(DFPatternCallback):
"""A class to create an identity operator with the LUT"""

def __init__(self):
def __init__(
self, params_class: Type, activation_type: str, calc_func: Callable[[float], float]
):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
)(wildcard())
self.pattern = (wildcard().has_attr({"Composite": params_class.composite_name}))(wildcard())
self.activation_type = activation_type
self.calc_func = calc_func

def callback(self, pre, post, node_map):
def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
id_input = post.args[0]

quantize_args = post.op.body.args
@@ -161,7 +165,9 @@ def callback(self, pre, post, node_map):
input_scale = float(dequantize_args[1].data.asnumpy())
input_zp = int(dequantize_args[2].data.asnumpy())

lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
lut_values = get_lut_from_func(
input_scale, input_zp, output_scale, output_zp, self.calc_func
)
lut = relay.const(lut_values, dtype="uint8")

# We baked the requantization into the LUT, so we don't requantize the identity operator
@@ -172,12 +178,21 @@
ifm_zero_point=input_zp,
ofm_scale=input_scale,
ofm_zero_point=input_zp,
activation="TANH",
activation=self.activation_type,
)

return identity


class TanhRewriter(LutActivationRewriter):
"""This pass adds tanh as a LUT to the identity operator"""

def __init__(self):
super().__init__(
params_class=ethosu_patterns.TanhParams, activation_type="TANH", calc_func=math.tanh
)


@ir.transform.module_pass(opt_level=1)
class LegalizeTanh:
"""This is the pass that wraps TanhRewriter"""
@@ -194,6 +209,48 @@ def __call__(self, *args, **kwargs):
pass


def sigmoid_calc_func(x: float) -> float:
"""Function to calculate the values for sigmoid"""
    # These limits are inherited from TFLite
upper_limit = 8.0
lower_limit = -8.0

if x <= lower_limit:
y = 0.0
elif x >= upper_limit:
y = 1.0
else:
y = 1 / (1 + math.exp(-x))
return y


class SigmoidRewriter(LutActivationRewriter):
"""This pass adds sigmoid as a LUT for identity op"""

def __init__(self):
super().__init__(
params_class=ethosu_patterns.SigmoidParams,
activation_type="SIGMOID",
calc_func=sigmoid_calc_func,
)


@ir.transform.module_pass(opt_level=1)
class LegalizeSigmoid:
"""This is the pass that wraps SigmoidRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(SigmoidRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class Conv2DRewriter(DFPatternCallback):
"""Convert conv2d related composite functions into ethosu_conv2d operators"""

@@ -1196,6 +1253,7 @@ def transform_module(
mod = LegalizeTanh()(mod)
mod = LegalizeMean()(mod)
mod = LegalizeConcat()(mod)
mod = LegalizeSigmoid()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
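
For reference, here is a minimal standalone sketch of the table construction that get_lut_from_func performs, using hypothetical quantization parameters (the real scales and zero points are read from the qnn.dequantize/qnn.quantize constants of the matched subgraph, and round_away_zero below only mirrors the semantics of util.round_away_zero):

import math

import numpy as np


def round_away_zero(x: float) -> float:
    # Round halfway cases away from zero, e.g. 2.5 -> 3.0, -2.5 -> -3.0
    return math.copysign(math.floor(abs(x) + 0.5), x)


def sigmoid(x: float) -> float:
    # Same clamping as sigmoid_calc_func above
    if x <= -8.0:
        return 0.0
    if x >= 8.0:
        return 1.0
    return 1.0 / (1.0 + math.exp(-x))


ifm_scale, ifm_zp = 0.0627, 0        # hypothetical input quantization
ofm_scale, ofm_zp = 1.0 / 256, -128  # hypothetical output quantization

qmin, qmax = np.iinfo(np.int8).min, np.iinfo(np.int8).max
lut = []
for q in range(qmin, qmax + 1):
    real = ifm_scale * (q - ifm_zp)  # dequantize the int8 input value
    requant = int(round_away_zero(ofm_zp + sigmoid(real) / ofm_scale))
    lut.append(min(qmax, max(qmin, requant)))  # clamp back into int8

assert len(lut) == 256  # one entry per possible int8 input value
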
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/te/convolution.py
@@ -140,11 +140,13 @@ def conv2d_compute(
"dilation_w": dilation_w,
}

has_lut = activation in ("TANH", "LUT", "SIGMOID")

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
if has_lut:
conv2d_attrs["lut"] = lut

conv = te.compute(
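
The same three-line change recurs in depthwise.py, identity.py and pooling.py below. As the comment in the diff notes, referencing lut[0] and lut[255] inside the compute body is a trick to make the LUT tensor an input of the TE graph, while the "lut" attribute lets later passes identify which input holds the table. A toy sketch of the idea with generic names (not the actual NPU compute body; the compute here only describes the operator for NPU lowering, so the dummy term does not reach the generated code):

from tvm import te

ifm = te.placeholder((1, 8, 8, 4), dtype="int8", name="ifm")
lut = te.placeholder((256,), dtype="uint8", name="lut")

# Touching two LUT elements creates a data dependency on the LUT tensor
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype)

identity = te.compute(
    ifm.shape,
    lambda n, h, w, c: ifm[n, h, w, c] + lut_expr,
    name="toy_identity_with_lut",
    attrs={"lut": lut},
)
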
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -139,11 +139,13 @@ def depthwise_conv2d_compute(
"dilation_w": dilation_w,
}

has_lut = activation in ("TANH", "LUT", "SIGMOID")

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
if has_lut:
depthwise_conv2d_attrs["lut"] = lut

depthwise = te.compute(
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/te/identity.py
@@ -61,11 +61,13 @@ def identity_compute(
dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
id_attrs = {"op": "ethosu_identity", "activation": activation}

has_lut = activation in ("TANH", "LUT", "SIGMOID")

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
if has_lut:
id_attrs["lut"] = lut

identity = te.compute(
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -123,11 +123,13 @@ def pooling_compute(
"upscale": upscale,
}

has_lut = activation in ("TANH", "LUT", "SIGMOID")

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
if has_lut:
pooling_attrs["lut"] = lut

pooling = te.compute(
37 changes: 31 additions & 6 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -918,27 +918,30 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return pattern


class TanhParams:
class LutActivationParams:
"""
This class will parse a call to a ethos-u.tanh composite function
and extract the parameter information.
    A parent class for LUT-based activation functions; it extracts the input
    and output tensors and checks whether they are valid.
"""

composite_name = "ethos-u.tanh"

def __init__(self, func_body: Call):
self.ofm = TensorParams(func_body)
self.ifm = TensorParams(func_body.args[0].args[0].args[0])

def is_valid(self):
"""
This function checks whether reshape has compatible attributes with the NPU
This function checks whether activation has compatible attributes with the NPU
"""
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
return False
return True


class TanhParams(LutActivationParams):

composite_name = "ethos-u.tanh"


def tanh_pattern():
"""Create pattern for tanh"""
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
@@ -947,6 +950,23 @@ def tanh_pattern():
return quant


class SigmoidParams(LutActivationParams):
"""
    This class will parse a call to an ethos-u.sigmoid composite function
and extract the parameter information.
"""

composite_name = "ethos-u.sigmoid"


def sigmoid_pattern():
"""Create pattern for sigmoid"""
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
sigmoid = is_op("sigmoid")(dequant)
quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant())
return quant


class MeanParams:
"""
This class will parse a call to ethosu.mean composite function
@@ -1162,6 +1182,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
lambda pat: MeanParams(pat).is_valid(),
),
(ConcatParams.composite_name, concat_pattern(), lambda pat: ConcatParams(pat).is_valid()),
(
SigmoidParams.composite_name,
sigmoid_pattern(),
lambda pat: SigmoidParams(pat).is_valid(),
),
]


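
The new pattern mirrors tanh_pattern: it matches the dequantize -> sigmoid -> quantize chain that the TFLite frontend emits for a quantized sigmoid. A quick sanity check using the dataflow pattern API (the scale and zero-point constants are hypothetical placeholders):

from tvm import relay
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
pattern = is_op("qnn.quantize")(is_op("sigmoid")(dequant), is_constant(), is_constant())

x = relay.var("x", shape=(1, 135, 41, 6), dtype="int8")
deq = relay.qnn.op.dequantize(x, relay.const(0.05, "float32"), relay.const(0, "int32"))
expr = relay.qnn.op.quantize(
    relay.sigmoid(deq),
    relay.const(1.0 / 256, "float32"),
    relay.const(-128, "int32"),
    out_dtype="int8",
)
assert pattern.match(expr)
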
74 changes: 17 additions & 57 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -815,66 +815,14 @@ def clz_comp(n):

@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_tflite_tanh(accel_type):
dtype = "int8"
ifm_shape = [1, 115, 32, 7]

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tanh_function(self, x):
op = tf.nn.tanh(x)
return op

model = Model()
concrete_func = model.tanh_function.get_concrete_function(
tf.TensorSpec(ifm_shape, dtype=tf.float32)
)

# Convert the model
def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

tflite_graph = create_tflite_graph()

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

relay_module, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)
mod = partition_for_ethosu(relay_module, params)

# Generate reference data
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
@tf.function
def tanh_func(x):
op = tf.nn.tanh(x)
return op

compiled_models = infra.build_source(
mod,
input_data,
output_data,
accel_type,
)

# Assumes only two runtime.Modules are created -- i.e. single offload module
ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]

# Verify generated C source
get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
compilation_artifacts = get_artifacts(ethosu_module)
cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
infra.print_payload(cmms)
infra.verify_source(compiled_models, accel_type)
_compare_tvm_with_tflite(tanh_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -896,5 +844,17 @@ def concat_func(*inputs):
_compare_tvm_with_tflite(concat_func, shapes, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_tflite_sigmoid(accel_type):
ifm_shape = [1, 135, 41, 6]

@tf.function
def sigmoid_function(x):
op = tf.nn.sigmoid(x)
return op

_compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)


if __name__ == "__main__":
pytest.main([__file__])