diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index e261f129bf50..fdd465529123 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1355,6 +1355,62 @@ def callback(self, pre, post, node_map):
         return ethosu_fc
 
 
+class PadRewriter(DFPatternCallback):
+    """Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d
+    operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.PadParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.PadParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+        channels_map = {
+            "NHWC": 3,
+        }
+        w_h, w_w = (1, 1)
+        # OHWI format for the ethosu_depthwise_conv2d kernel weights
+        weight_shape = (params.ifm.shape[-1], w_h, w_w, 1)
+        weights = relay.const(np.full(weight_shape, 1), params.ifm.dtype)
+        scale_bias = vela_api.pack_biases(
+            biases=np.zeros(params.ifm.shape[-1]),
+            ifm_scale=params.ifm.q_params.scale_f32,
+            ifm_dtype=np.dtype(params.ifm.dtype),
+            weight_scales=np.array(1.0, dtype=np.float32),
+            ofm_scale=params.ofm.q_params.scale_f32,
+            is_activation_tanh_or_sigmoid=False,
+        )
+
+        return ethosu_ops.ethosu_depthwise_conv2d(
+            ifm=post.args[0],
+            weight=weights,
+            scale_bias=relay.const(scale_bias, "uint8"),
+            lut=relay.const([], "int8"),
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point.item()),
+            weight_zero_point=0,
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point.item()),
+            kernel_shape=(w_h, w_w),
+            ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]],
+            strides=(1, 1),
+            padding=params.padding,
+            dilation=(1, 1),
+            activation="NONE",
+            clip_min=0,
+            clip_max=0,
+            upscale="NONE",
+            ifm_layout=str(params.ifm.layout),
+            ofm_layout=str(params.ofm.layout),
+            ofm_dtype=str(params.ofm.dtype),
+        )
+
+
 @util.create_npu_function_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -1375,6 +1431,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
             FullyConnectedRewriter(),
             MaxPoolingRewriter(),
             AvgPoolingRewriter(),
+            PadRewriter(),
             AddRewriter(),
             SubRewriter(),
             MulRewriter(),
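A note on why the rewrite above is sound: with a 1x1 all-ones kernel, zero bias, a unit weight scale, and the same scale and zero point on IFM and OFM, the depthwise convolution is a quantized identity, so its only visible effect is the border that the hardware fills with the IFM zero point. A minimal NumPy sketch of the expected numerics (toy shape, zero point, and padding; none of these values come from the patch):

import numpy as np

# Toy NHWC int8 input; values are arbitrary.
ifm = np.arange(9, dtype=np.int8).reshape(1, 3, 3, 1)
ifm_zero_point = 5
top, left, bottom, right = 1, 2, 1, 2  # same ordering as params.padding

# The NPU fills the new border with the IFM zero point; the identity
# depthwise conv then passes every element through unchanged.
expected = np.pad(
    ifm,
    [(0, 0), (top, bottom), (left, right), (0, 0)],
    mode="constant",
    constant_values=ifm_zero_point,
)
print(expected.shape)  # (1, 5, 7, 1)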
+ """ + + IFM = 0 + IFM_ZERO_POINT = 1 + + def is_npu_func(func: relay.Function) -> bool: """Check if the given function is an NPU function.""" return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u" diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index c0f8e5e9708e..a86357db39fc 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1772,6 +1772,86 @@ def hard_swish_pattern(): return quantize +class PadParams: + """ + This class will parse a call to a ethosu.pad2d composite function + and extract the parameter information. + """ + + composite_name = "ethos-u.pad2d" + # The ethos-u.pad2d composite function will be transformed to the + # ethosu_depthwise_conv2d operator. + # For the ethosu_depthwise_conv2d the hardware only supports padding + # upto the numbers as follows, so we define such padding limits + padding_bounds = [31, 31, 32, 32] + + def __init__(self, func_body: Call): + from tvm.relay.backend.contrib.ethosu.util import QPad2DArgs + + # there is no 'layout' attribute in nn.pad + layout = "NHWC" + self.ifm = TensorParams( + tensor=func_body.args[QPad2DArgs.IFM.value], + layout=layout, + scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))), + zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value], + ) + + self.padding = self.extract_padding(func_body) + self.ofm = TensorParams( + tensor=func_body, + layout=layout, + scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))), + zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value], + ) + + @staticmethod + def extract_padding( + padding: relay.Call, + ) -> Optional[Tuple[int, int, int, int]]: + """ + Here we check whether a separate padding operation can be rewritten + as NPU depthwise convolution. If the padding specified by the + separate nn.pad operation is not supported by NPU depthwise convolution, + None will be returned. This will cause the nn.pad not to be offloaded to NPU. 
+ """ + pad_width = padding.attrs["pad_width"] + if len(pad_width) != 4: + return None + if list(pad_width[0]) != [0, 0] or list(pad_width[3]) != [0, 0]: + return None + return [ + pad_width[1][0], + pad_width[2][0], + pad_width[1][1], + pad_width[2][1], + ] + + def is_valid(self): + """ + This function checks whether pad has compatible attributes + with the NPU depthwise convolution + """ + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): + return False + if self.ifm.dtype != self.ofm.dtype: + return False + if not check_batch_size(self.ifm): + return False + if not self.padding or not check_padding(self.padding, self.padding_bounds): + return False + if not check_dimensions(self.ifm) or not check_dimensions(self.ofm): + return False + return True + + +def pad_pattern(): + """Create pattern for pad""" + pattern = is_op("nn.pad")(wildcard(), is_constant()) + return pattern + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1805,6 +1885,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal qnn_avgpool2d_pattern(), lambda pat: AvgPool2DParams(pat).is_valid(), ), + ( + PadParams.composite_name, + pad_pattern(), + lambda pat: PadParams(pat).is_valid(), + ), ( AddParams.composite_name, qnn_add_pattern(), diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index e06e36638d7f..13b54b988963 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -258,6 +258,29 @@ def depthwise_conv2d(x): infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], "ethos-u55-256") +@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)]) +@pytest.mark.parametrize("padding", [(0, 1, 0, 0), (1, 1, 1, 1), (1, 1, 5, 5)]) +@pytest.mark.parametrize("const_value", [0, 5, 125, -5]) +def test_tflite_separate_pad( + ifm_shape, + padding, + const_value, +): + + np.random.seed(0) + + @tf.function + def pad2d(x): + return tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + const_value, + ) + + infra.compare_tvm_with_tflite(pad2d, [ifm_shape], "ethos-u55-256") + + @pytest.mark.parametrize( "accel_type", ACCEL_TYPES,