87 changes: 87 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1584,6 +1584,92 @@ def __call__(self, *args, **kwargs):
pass


class FullyConnectedRewriter(DFPatternCallback):
"""Legalize Fully Connected (with bias and clip) to an NPU operator"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.FullyConnectedParams.composite_name})
)(wildcard())

def callback(self, pre, post, node_map):
params = ethosu_patterns.FullyConnectedParams(post.op.body)
params.ifm.tensor = post.args[0]

# IFM reshapes
ifm = post.args[0]
if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == params.ifm.shape[2] == 1:
ifm = relay.reshape(ifm, (1, 1, 1, params.ifm.shape[-1]))

# Weight transformations
weights_values = params.weights.values
weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2))
if params.activation:
activation = "CLIP"
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0
bias_values = (
params.biases.tensor.data.asnumpy()
if params.biases
else np.zeros((params.ofm.shape[-1]))
)
scale_bias = vela_api.pack_biases(
biases=bias_values,
ifm_scale=params.ifm.q_params.scale_f32,
ifm_dtype=np.dtype(params.ifm.dtype),
weight_scales=params.weights.q_params.scale_f32,
ofm_scale=params.ofm.q_params.scale_f32,
is_activation_tanh_or_sigmoid=False,
)
ethosu_fc = ethosu_ops.ethosu_conv2d(
ifm=ifm,
weight=relay.const(weights_values_ohwi, params.weights.values.dtype),
scale_bias=relay.const(scale_bias, "uint8"),
lut=relay.const([], dtype="int8"),
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
weight_zero_point=int(params.weights.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
kernel_shape=[1, 1],
ofm_channels=params.weights.shape[0],
strides=(1, 1),
padding=(0, 0, 0, 0),
dilation=(1, 1),
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
upscale="NONE",
ifm_layout="NHWC",
ofm_layout="NHWC",
)

if len(params.ofm.shape) != 4 or not params.ofm.shape[1] == params.ofm.shape[2] == 1:
ethosu_fc = relay.reshape(ethosu_fc, params.ofm.shape)
Comment on lines +1652 to +1653 (Contributor):

I suspect there isn't a test case that exercises this, since on line 1700 this pass runs after the no-op legalizer, so the last reshape won't have a following identity op and will fall over in TE.

return ethosu_fc
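
The rewriter above implements a fully connected layer as a 1x1 NPU convolution: the IFM is reshaped to NHWC (1, 1, 1, C) and the dense weights (O, I) are expanded to OHWI (O, 1, 1, I). A minimal NumPy sketch of that equivalence, with illustrative shapes that are not taken from the codebase:

import numpy as np

# Dense: (1, 8) IFM times (8, 4) weights -> (1, 4) OFM.
ifm = np.random.randint(-128, 127, size=(1, 8), dtype=np.int8)
weights = np.random.randint(-128, 127, size=(4, 8), dtype=np.int8)  # (O, I)
dense_out = ifm.astype(np.int32) @ weights.T.astype(np.int32)

# The same computation as a 1x1 convolution over an NHWC (1, 1, 1, 8) input
# with OHWI (4, 1, 1, 8) weights, mirroring the reshapes in the rewriter.
ifm_nhwc = ifm.reshape(1, 1, 1, 8)
weights_ohwi = np.expand_dims(weights, axis=(1, 2))
conv_out = np.einsum(
    "nhwi,ohwi->nhwo",
    ifm_nhwc.astype(np.int32),
    weights_ohwi.astype(np.int32),
)

assert np.array_equal(conv_out.reshape(1, 4), dense_out)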


@ir.transform.module_pass(opt_level=1)
class LegalizeFullyConnected:
"""This is the pass that wraps the FullyConnectedRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(FullyConnectedRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -1621,6 +1707,7 @@ def transform_module(
mod = LegalizeSqueeze()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeFullyConnected()(mod)
mod = LegalizeNoOps()(mod)
return mod

14 changes: 14 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
@@ -129,6 +129,20 @@ class DequantizeArgs(Enum):
IFM_ZERO_POINT = 2


class QDenseArgs(Enum):
"""
This is a helper enum to access the correct indices of
qnn.dense arguments
"""

IFM = 0
WEIGHTS = 1
IFM_ZERO_POINT = 2
WEIGHTS_ZERO_POINT = 3
IFM_SCALE = 4
WEIGHTS_SCALE = 5
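
A minimal usage sketch for these indices; the helper name dense_tensors is hypothetical, and FullyConnectedParams in ethosu.py does the same thing inline:

from tvm import relay
from tvm.relay.backend.contrib.ethosu.util import QDenseArgs

def dense_tensors(qnn_dense: relay.Call):
    """Pull the quantized operands out of a qnn.dense call by position."""
    return (
        qnn_dense.args[QDenseArgs.IFM.value],
        qnn_dense.args[QDenseArgs.WEIGHTS.value],
        qnn_dense.args[QDenseArgs.IFM_SCALE.value],
        qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value],
    )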


def is_composite_func(func: relay.Function, name: str) -> bool:
"""
This method checks whether the call is to
109 changes: 109 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1537,6 +1537,110 @@ def squeeze_pattern():
return is_op("squeeze")(wildcard())


class FullyConnectedParams:
"""
This class will parse a call to an ethos-u.fully_connected composite
function and extract the parameter information.
"""

composite_name = "ethos-u.fully_connected"

@requires_vela
def __init__(self, func_body):
from tvm.relay.backend.contrib.ethosu.util import QDenseArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

self.activation = None
if str(func_body.op) == "clip":
self.activation = func_body
requantize_op = self.activation.args[0]
else:
requantize_op = func_body

call = requantize_op.args[0]
if str(requantize_op.args[0].op) == "nn.bias_add":
bias_add = call
qnn_dense = call.args[0]
else:
bias_add = None
qnn_dense = call

# weights & biases are params as they should be constant
self.weights = TensorParams(
qnn_dense.args[QDenseArgs.WEIGHTS.value],
None,
qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value],
qnn_dense.args[QDenseArgs.WEIGHTS_ZERO_POINT.value],
)
self.biases = (
TensorParams(
bias_add.args[BiasAddArgs.BIASES.value],
None,
requantize_op.args[RequantArgs.IFM_SCALE.value],
requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
)
if bias_add
else None
)
self.ifm = TensorParams(
qnn_dense.args[QDenseArgs.IFM.value],
None,
qnn_dense.args[QDenseArgs.IFM_SCALE.value],
qnn_dense.args[QDenseArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
func_body,
None,
requantize_op.args[RequantArgs.OFM_SCALE.value],
requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
)

def is_valid(self) -> bool:
"""
Checks whether Fully Connected has compatible attributes with HW
"""

def check_weights_fc(weights):
"""Checks whether weight tensor is compatible with HW"""
weights_limit = 127 * 65536
# A saturation upper bound check for accumulators
weights.values = weights.values - weights.q_params.zero_point
axis = 1
sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis))
if not sum_weights <= weights_limit:
return False
return True

if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
return False
if not check_weights_fc(self.weights):
return False
if not check_bias(self.biases):
return False
if not check_batch_size(self.ifm):
return False
# Check input shape
if not len(self.ifm.shape) == 2:
return False
# Check output shape
if not len(self.ofm.shape) == 2:
return False
return True
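
FullyConnectedParams.__init__ above walks the composite body from the outside in; the nesting it expects (clip and nn.bias_add both optional, matching qnn_fc_pattern below) is roughly the following (illustrative Relay-style pseudocode, not exact code):

# Outermost first; positional arguments follow QDenseArgs / RequantArgs:
#
#   clip(                                  # optional -> params.activation
#       qnn.requantize(
#           nn.bias_add(                   # optional -> params.biases
#               qnn.dense(ifm, weights,
#                         ifm_zero_point, weights_zero_point,
#                         ifm_scale, weights_scale),
#               biases),
#           ifm_scale, ifm_zero_point,
#           ofm_scale, ofm_zero_point))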


def qnn_fc_pattern():
dense = is_op("qnn.dense")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
optional_bias_add = is_op("nn.bias_add")(dense, is_constant())
req = is_op("qnn.requantize")(
dense | optional_bias_add, is_constant(), is_constant(), is_constant(), is_constant()
)
optional_clip = req.optional(is_op("clip"))
return optional_clip


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
@@ -1555,6 +1659,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
qnn_conv2d_transpose_pattern(),
lambda pat: QnnConv2DTransposeParams(pat).is_valid(),
),
(
FullyConnectedParams.composite_name,
qnn_fc_pattern(),
lambda pat: FullyConnectedParams(pat).is_valid(),
),
(
MaxPool2DParams.composite_name,
qnn_maxpool2d_pattern(),
30 changes: 30 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -1182,5 +1182,35 @@ def leaky_relu_func(x):
_compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
@pytest.mark.parametrize("ofm_channels", [32, 64])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
def test_tflite_fully_connected(
accel_type,
ifm_shape,
ofm_channels,
use_bias,
activation_function,
):
@tf.function
def fully_connected(x):
bias_shape = ofm_channels
bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
w = tf.constant(
np.random.uniform(size=[ifm_shape[1], ofm_channels]),
dtype=tf.float32,
)
x = tf.matmul(x, w)
if use_bias:
x = tf.nn.bias_add(x, bias)
if activation_function == "RELU":
x = tf.nn.relu(x)
return x

_compare_tvm_with_tflite(fully_connected, [ifm_shape], accel_type)


if __name__ == "__main__":
pytest.main([__file__])
103 changes: 103 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -2421,5 +2421,108 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
@pytest.mark.parametrize("ofm_channels", [32, 64])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
def test_tflite_fully_connected(
ifm_shape,
ofm_channels,
use_bias,
activation_function,
):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def fully_connected(self, x):
bias_shape = ofm_channels
bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
w = tf.constant(
np.random.uniform(size=[ifm_shape[1], ofm_channels]),
dtype=tf.float32,
)
x = tf.matmul(x, w)
if use_bias:
x = tf.nn.bias_add(x, bias)
if activation_function == "RELU":
x = tf.nn.relu(x)
return x

model = Model()
concrete_func = model.fully_connected.get_concrete_function(
tf.TensorSpec(ifm_shape, dtype=tf.float32)
)
# Convert the model
def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

def verify(ext_func):
op = ext_func.body.args[0]
ofm_channels = op.attrs.ofm_channels

# check IFM
ifm = op.args[0].checked_type
assert list(ifm.shape) == [1, 1] + list(ifm_shape)
assert str(ifm.dtype) == dtype

# check OFM
ofm = op.checked_type
assert list(ofm.shape) == [1, 1, 1, ofm_channels]
assert str(ofm.dtype) == dtype

# check weights
weights_ohwi = op.args[1].data.asnumpy()
assert str(weights_ohwi.dtype) == dtype
assert list(weights_ohwi.shape) == [ofm_channels, 1, 1, ifm_shape[1]]

# Check that scale_bias matches weight tensor
assert list(op.args[2].checked_type.shape)[0] == ofm_channels

assert list(op.attrs.padding) == [0, 0, 0, 0]
Contributor comment:

Nit: might also be worth checking that the op name is an NPU convolution here as well.
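
A hedged one-liner for that check, assuming the Ethos-U convolution is registered under the op name "contrib.ethosu.ethosu_conv2d":

# Hypothetical extra assertion for verify(); the op name is an assumption.
assert op.op.name == "contrib.ethosu.ethosu_conv2d"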

assert list(op.attrs.strides) == [1, 1]
assert list(op.attrs.dilation) == [1, 1]
if activation_function == "RELU":
assert str(op.attrs.activation) == "CLIP"

fc_pattern_table = [
(
ethosu.FullyConnectedParams.composite_name,
ethosu.qnn_fc_pattern(),
lambda pat: ethosu.FullyConnectedParams(pat).is_valid(),
)
]

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, fc_params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)

mod["main"] = bind_params_by_name(mod["main"], fc_params)
mod = partition_ethosu_by_table(mod, fc_pattern_table)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.FullyConnectedRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)

verify(mod["tvmgen_default_ethos_u_main_0"])


if __name__ == "__main__":
pytest.main([__file__])