diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index deb900d52d09..64b2dc20981d 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -106,6 +106,25 @@ struct DequantizeAttrs : public tvm::AttrsNode { } }; +/*! \brief Attribute for broadcast operator */ +struct BroadcastAttrs : public tvm::AttrsNode { + int lhs_axis; + int rhs_axis; + + TVM_DECLARE_ATTRS(BroadcastAttrs, "relay.attrs.BroadcastAttrs") { + TVM_ATTR_FIELD(lhs_axis) + .describe( + "The channel axis for channel wise broadcast. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + TVM_ATTR_FIELD(rhs_axis) + .describe( + "The channel axis for channel wise broadcast. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + } +}; + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 7799060816a3..8b92fdf2672d 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -494,6 +494,11 @@ class OneHotAttrs(Attrs): """Attributes used in one_hot operators""" +@tvm._ffi.register_object("relay.attrs.BroadcastAttrs") +class BroadcastAttrs(Attrs): + """Attributes used in broadcast operators""" + + @tvm._ffi.register_object("relay.attrs.QuantizeAttrs") class QuantizeAttrs(Attrs): """Attributes used in quantize operators""" diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index ab2675004868..10c2df68d4ee 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -593,7 +593,16 @@ def conv2d_transpose( def add( - lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point + lhs, + rhs, + lhs_scale, + lhs_zero_point, + rhs_scale, + rhs_zero_point, + output_scale, + output_zero_point, + lhs_axis=-1, + rhs_axis=-1, ): """Quantized addition with numpy-style broadcasting. 
@@ -623,6 +632,14 @@ def add( output_zero_point: relay.Expr The zero point of output quantized expr. + lhs_axis: int + The channel axis for lhs quantization. Default value is -1 which corresponds + to the last axis. + + rhs_axis: int + The channel axis for rhs quantization. Default value is -1 which corresponds + to the last axis. + Returns ------- result : relay.Expr @@ -638,6 +655,8 @@ def add( rhs_zero_point, output_scale, output_zero_point, + lhs_axis, + rhs_axis, ) @@ -702,7 +721,16 @@ def dense( def mul( - lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point + lhs, + rhs, + lhs_scale, + lhs_zero_point, + rhs_scale, + rhs_zero_point, + output_scale, + output_zero_point, + lhs_axis=-1, + rhs_axis=-1, ): """Quantized multiplication with numpy-style broadcasting. @@ -732,6 +760,14 @@ def mul( output_zero_point: relay.Expr The zero point of output quantized expr. + lhs_axis: int + The channel axis for lhs quantization. Default value is -1 which corresponds + to the last axis. + + rhs_axis: int + The channel axis for rhs quantization. Default value is -1 which corresponds + to the last axis. + Returns ------- result : relay.Expr @@ -747,6 +783,8 @@ def mul( rhs_zero_point, output_scale, output_zero_point, + lhs_axis, + rhs_axis, ) @@ -961,7 +999,16 @@ def sigmoid(x, scale, zero_point, output_scale, output_zero_point): def subtract( - lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point + lhs, + rhs, + lhs_scale, + lhs_zero_point, + rhs_scale, + rhs_zero_point, + output_scale, + output_zero_point, + lhs_axis=-1, + rhs_axis=-1, ): """Quantized subtraction with numpy-style broadcasting. @@ -991,6 +1038,14 @@ def subtract( output_zero_point: relay.Expr The zero point of output quantized expr. + lhs_axis: int + The channel axis for lhs quantization. Default value is -1 which corresponds + to the last axis. + + rhs_axis: int + The channel axis for rhs quantization. 
Default value is -1 which corresponds + to the last axis. + Returns ------- result : relay.Expr @@ -1006,6 +1061,8 @@ def subtract( rhs_zero_point, output_scale, output_zero_point, + lhs_axis, + rhs_axis, ) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 4cd200611115..38af8911bc53 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -451,6 +451,8 @@ def binary(expr, type_map): right_t.zero_point, out_t.scale, out_t.zero_point, + left_t.axis, + right_t.axis, ) return [out, out_t] diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc index b0dc3e4af5c4..56e97674def4 100644 --- a/src/relay/qnn/op/add.cc +++ b/src/relay/qnn/op/add.cc @@ -45,6 +45,12 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Get the input dtype and shape. QnnBinaryOpTensorType input_type(arg_types, 0); + const auto* broadcast_attrs = attrs.as(); + ICHECK(broadcast_attrs != nullptr); + + auto lhs_axis = broadcast_attrs->lhs_axis; + auto rhs_axis = broadcast_attrs->rhs_axis; + // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in // the start, we can insert requantize at the end if both input tensors have same qnn params. In // that case, we can first add the tensors, subtract the zero point, and requantize at the end. @@ -68,11 +74,11 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Requantize LHS if necessary. Computes Q_a' auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale, - args.output_zero_point, input_type.shape); + args.output_zero_point, input_type.shape, lhs_axis); // Requantize RHS if necessary. 
Computes Q_b' auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale, - args.output_zero_point, input_type.shape); + args.output_zero_point, input_type.shape, rhs_axis); // Computes Q_a' + Q_b' auto output = Add(requantized_lhs, requantized_rhs); diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc index 781114cc5f5a..87ee7d2f1f4d 100644 --- a/src/relay/qnn/op/mul.cc +++ b/src/relay/qnn/op/mul.cc @@ -42,6 +42,8 @@ namespace qnn { */ Expr QnnMulCanonicalize(const Attrs& attrs, const Array& new_args, const Array& arg_types) { + Expr output; + // Get the attrs. QnnBinaryOpArguments args(new_args); @@ -51,44 +53,108 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array& new_args, const auto int32_dtype = DataType::Int(32); const auto float32_dtype = DataType::Float(32); - /* - A tensor multiplication c = a * b can be written in terms of respective - quantized tensors, scales and zero points as - S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b). - - We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different - quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' = - 0. The quantized multiplication then becomes - Q_c = S'/S_c Q' + z_c, - which is essentially a requantization of tensor Q' into tensor Q_c. - */ - - auto lhs_shifted = Cast(args.lhs, int32_dtype); - auto rhs_shifted = Cast(args.rhs, int32_dtype); - - auto zero_scalar = MakeConstantScalar(int32_dtype, 0); - if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) { - lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point); + const auto* broadcast_attrs = attrs.as(); + ICHECK(broadcast_attrs != nullptr); + + auto lhs_axis = broadcast_attrs->lhs_axis; + auto rhs_axis = broadcast_attrs->rhs_axis; + + if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) { + /* + This is per-tensor quantized multiply. 
+ + A tensor multiplication c = a * b can be written in terms of respective + quantized tensors, scales and zero points as + S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b). + + We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different + quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' = + 0. The quantized multiplication then becomes + Q_c = S'/S_c Q' + z_c, + which is essentially a requantization of tensor Q' into tensor Q_c. + */ + + auto lhs_shifted = Cast(args.lhs, int32_dtype); + auto rhs_shifted = Cast(args.rhs, int32_dtype); + + auto zero_scalar = MakeConstantScalar(int32_dtype, 0); + if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) { + lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point); + } + + if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) { + rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point); + } + + // Create a new tensor Q' + output = Multiply(lhs_shifted, rhs_shifted); + + // Get the adjusted new scale and zero points. + float lhs_scale_float = GetScalarFromConstant(args.lhs_scale); + float rhs_scale_float = GetScalarFromConstant(args.rhs_scale); + float new_scale_float = lhs_scale_float * rhs_scale_float; + auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float); + auto new_input_zero_point = zero_scalar; + + // Requantize to get Q_c + output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point, + args.output_scale, args.output_zero_point, input_type.dtype); + } else if (lhs_axis == rhs_axis) { + /* + This is per-channel quantized multiply, assuming lhs_axis and rhs_axis are the same. + The subtract is done on the specified axis via broadcast. Then, we multiply lhs and rhs. + The output is requantized using new scale and axis. TODO: support different axes. 
+ */ + + auto lhs_data = Cast(args.lhs, int32_dtype); + auto rhs_data = Cast(args.rhs, int32_dtype); + + auto zero_scalar = MakeConstantScalar(int32_dtype, 0); + if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) { + // Broadcast lhs zero point if needed + int rank = static_cast(input_type.shape.size()); + int axis = (lhs_axis < 0) ? ((rank > 0) ? rank + lhs_axis : 0) : lhs_axis; + Expr lhs_zero_broadcast = ExpandBiasToMatchAxis(Reshape(args.lhs_zero_point, + { + -1, + }), + rank, {axis}); + lhs_data = Subtract(lhs_data, Cast(lhs_zero_broadcast, DataType::Int(32))); + } + + if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) { + // Broadcast rhs zero point if needed + int rank = static_cast(input_type.shape.size()); + int axis = (rhs_axis < 0) ? ((rank > 0) ? rank + rhs_axis : 0) : rhs_axis; + Expr rhs_zero_broadcast = ExpandBiasToMatchAxis(Reshape(args.rhs_zero_point, + { + -1, + }), + rank, {axis}); + rhs_data = Subtract(rhs_data, Cast(rhs_zero_broadcast, DataType::Int(32))); + } + + // Create a new tensor Q' + output = Multiply(lhs_data, rhs_data); + + // Requantize to get Q_c + auto lhs_scales = GetFloatVectorFromConstant(args.lhs_scale); + auto rhs_scales = GetFloatVectorFromConstant(args.rhs_scale); + std::vector output_multipliers; + for (size_t i = 0; i < lhs_scales.size(); i++) { + double multiplier = static_cast(lhs_scales[i]) * static_cast(rhs_scales[i]); + output_multipliers.push_back(multiplier); + } + auto new_input_scale = MakeConstantTensor( + DataType::Float(32), {(int64_t)output_multipliers.size()}, output_multipliers); + + output = Requantize(output, input_type.shape, new_input_scale, zero_scalar, args.output_scale, + args.output_zero_point, input_type.dtype, lhs_axis); + + } else { + LOG(FATAL) << "Not supported: lhs_axis and rhs_axis are not the same."; } - if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) { - rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point); - } - - // Create a new tensor Q' - auto output = 
Multiply(lhs_shifted, rhs_shifted); - - // Get the adjusted new scale and zero points. - float lhs_scale_float = GetScalarFromConstant(args.lhs_scale); - float rhs_scale_float = GetScalarFromConstant(args.rhs_scale); - float new_scale_float = lhs_scale_float * rhs_scale_float; - auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float); - auto new_input_zero_point = zero_scalar; - - // Requantize to get Q_c - output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point, - args.output_scale, args.output_zero_point, input_type.dtype); - return output; } diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index de97fb860b8a..6d1eb3a34386 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -40,6 +40,8 @@ namespace tvm { namespace relay { namespace qnn { +TVM_REGISTER_NODE_TYPE(BroadcastAttrs); + /* * Number of inputs for the Qnn binary operators. * Refer the QNN_REGISTER_BINARY_OP macro to see @@ -191,12 +193,13 @@ inline Expr ConvertDtype(const Expr& expr, const DataType& target_dtype) { inline Expr RequantizeOrUpcast(const Expr& expr, const Expr& expr_scale, const Expr& expr_zero_point, const Expr& target_scale, const Expr& target_zero_point, const Array& expr_shape, + const int& axis = -1, const DataType& target_dtype = DataType::Int(32)) { auto result = expr; if (!IsEqualScalar(expr_scale, target_scale) || !IsEqualScalar(expr_zero_point, target_zero_point)) { result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, target_scale, - target_zero_point, target_dtype); + target_zero_point, target_dtype, axis); } else { result = Cast(result, target_dtype); } @@ -243,10 +246,62 @@ static inline bool QnnBroadcastRel(const Array& types, int num_inputs, con return false; } } - ICHECK(IsScalarType(types[2], DataType::Float(32))); // lhs_scale - ICHECK(IsScalarType(types[3], DataType::Int(32))); // lhs_zero_point - ICHECK(IsScalarType(types[4], DataType::Float(32))); // rhs_scale 
- ICHECK(IsScalarType(types[5], DataType::Int(32))); // rhs_zero_point + + const auto* lhs_data = types[0].as(); + const auto* rhs_data = types[1].as(); + + if (lhs_data == nullptr || rhs_data == nullptr) { + return false; + } + + const BroadcastAttrs* broadcast_attrs = attrs.as(); + int lhs_axis = broadcast_attrs->lhs_axis; + int rhs_axis = broadcast_attrs->rhs_axis; + + auto lhs_rank = static_cast(lhs_data->shape.size()); + auto rhs_rank = static_cast(rhs_data->shape.size()); + + lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) : lhs_axis; + rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_rank + rhs_axis : 0) : rhs_axis; + + // If zero point and scale are scalar then axis doesn't matter. + bool lhs_scale_is_scalar = (types[2].as())->shape.size() == 0; + bool lhs_zp_is_scalar = (types[3].as())->shape.size() == 0; + bool rhs_scale_is_scalar = (types[4].as())->shape.size() == 0; + bool rhs_zp_is_scalar = (types[5].as())->shape.size() == 0; + + if (!(lhs_scale_is_scalar && lhs_zp_is_scalar)) { + ICHECK_LT(lhs_axis, lhs_rank > 0 ? lhs_rank : 1) + << "lhs_axis " << broadcast_attrs->lhs_axis << " is out of range"; + ICHECK_GE(lhs_axis, 0) << "lhs_axis " << broadcast_attrs->lhs_axis << " is out of range"; + } + + if (!(rhs_scale_is_scalar && rhs_zp_is_scalar)) { + ICHECK_LT(rhs_axis, rhs_rank > 0 ? rhs_rank : 1) + << "rhs_axis " << broadcast_attrs->rhs_axis << " is out of range"; + ICHECK_GE(rhs_axis, 0) << "rhs_axis " << broadcast_attrs->rhs_axis << " is out of range"; + } + + PrimExpr lhs_axis_shape; + if (lhs_rank > 0) { + lhs_axis_shape = lhs_data->shape[lhs_axis]; + } else { + lhs_axis_shape = Integer(1); + } + + PrimExpr rhs_axis_shape; + if (rhs_rank > 0) { + rhs_axis_shape = rhs_data->shape[rhs_axis]; + } else { + rhs_axis_shape = Integer(1); + } + + // Check and assign types for scale and zero points. 
+ AssignType(types[2], DataType::Float(32), lhs_axis_shape, reporter); // lhs_scale + AssignType(types[3], DataType::Int(32), lhs_axis_shape, reporter); // lhs_zero_point + AssignType(types[4], DataType::Float(32), rhs_axis_shape, reporter); // rhs_scale + AssignType(types[5], DataType::Int(32), rhs_axis_shape, reporter); // rhs_zero_point + ICHECK(IsScalarType(types[6], DataType::Float(32))); // output_scale ICHECK(IsScalarType(types[7], DataType::Int(32))); // output_zero_point @@ -269,14 +324,19 @@ static inline bool QnnBroadcastRel(const Array& types, int num_inputs, con #define QNN_REGISTER_BINARY_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ - Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ + Expr rhs_zero_point, Expr output_scale, Expr output_zero_point, \ + int lhs_axis, int rhs_axis) { \ static const Op& op = Op::Get("qnn." OpName); \ + auto attrs = make_object(); \ + attrs->lhs_axis = lhs_axis; \ + attrs->rhs_axis = rhs_axis; \ return Call(op, \ {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, \ output_zero_point}, \ - Attrs(), {}); \ + Attrs(attrs), {}); \ }); \ RELAY_REGISTER_OP("qnn." 
OpName) \ + .set_attrs_type() \ .set_num_inputs(kNumQnnBinaryOpInputs) \ .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ @@ -286,6 +346,8 @@ static inline bool QnnBroadcastRel(const Array& types, int num_inputs, con .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ + .add_argument("lhs_axis", "Tensor", "The channel quantization of the lhs tensor.") \ + .add_argument("rhs_axis", "Tensor", "The channel quantization of the rhs tensor.") \ .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ .set_attr("TNonComputational", true) \ .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) diff --git a/src/relay/qnn/op/subtract.cc b/src/relay/qnn/op/subtract.cc index b928bd5e465c..1ec3c7a6531c 100644 --- a/src/relay/qnn/op/subtract.cc +++ b/src/relay/qnn/op/subtract.cc @@ -45,6 +45,12 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array& new_args, // Get the input dtype and shape. QnnBinaryOpTensorType input_type(arg_types, 0); + const auto* broadcast_attrs = attrs.as(); + ICHECK(broadcast_attrs != nullptr); + + auto lhs_axis = broadcast_attrs->lhs_axis; + auto rhs_axis = broadcast_attrs->rhs_axis; + // TODO(shoubhik) - The lowering can be further optimized. Instead of inserting requantize in // the start, we can insert requantize at the end if both input tensors have same qnn params. In // that case, we can first subtract the tensors, add the zero point, and requantize at the end. @@ -68,11 +74,11 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array& new_args, // Requantize LHS if necessary. 
Computes Q_a' auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale, - args.output_zero_point, input_type.shape); + args.output_zero_point, input_type.shape, lhs_axis); // Requantize RHS if necessary. Computes Q_b' auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale, - args.output_zero_point, input_type.shape); + args.output_zero_point, input_type.shape, rhs_axis); // Computes Q_a' - Q_b' auto output = Subtract(requantized_lhs, requantized_rhs); diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index d7769707f01e..b4841c8ddda8 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -106,9 +106,11 @@ std::string SelectRequntizeParameter(const std::string& arg_value, const std::st static inline Expr Requantize(const Expr& data, const Array& input_shape, const Expr& input_scale, const Expr& input_zero_point, const Expr& output_scale, const Expr& output_zero_point, - const DataType& out_dtype, const std::string& rounding = "None", + const DataType& out_dtype, const int& axis = -1, + const std::string& rounding = "None", const std::string& compute_dtype = "None") { auto attrs = make_object(); + attrs->axis = axis; attrs->out_dtype = std::move(out_dtype); const RequantizeConfig& cfg = RequantizeConfig::Current(); attrs->rounding = diff --git a/tests/python/contrib/test_arm_compute_lib/test_add.py b/tests/python/contrib/test_arm_compute_lib/test_add.py index d7abc5c414fb..ba324358f8e5 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_add.py +++ b/tests/python/contrib/test_arm_compute_lib/test_add.py @@ -74,6 +74,10 @@ def _get_expected_codegen(shape, dtype, op_name, qnn_params): }, } + if qnn_params: + node["attrs"]["lhs_axis"] = [["-1"]] + node["attrs"]["rhs_axis"] = [["-1"]] + return [*inputs, node] diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py 
index 1ac5674b48d5..602671af41ac 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -600,13 +600,106 @@ def test_fake_quantize_binary(operator): compare_fq_to_int(op, [x_np, y_np]) +@pytest.mark.parametrize( + "operator", + [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], +) +def test_fake_quantize_binary_per_channel(operator): + def verify_binary_per_channel(lhs_scale, rhs_scale, lhs_zp, rhs_zp, out_zp, lhs_axis, rhs_axis): + if operator == relay.op.multiply: + out_scale = relay.const(2.0) + rhs_axis = lhs_axis # TODO: Support different axes for per-channel quantized multiply + else: + out_scale = relay.const(0.1) + + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + x = relay.qnn.op.dequantize(x, relay.const(lhs_scale), relay.const(lhs_zp), axis=lhs_axis) + + y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8") + y = relay.qnn.op.dequantize(y, relay.const(rhs_scale), relay.const(rhs_zp), axis=rhs_axis) + + op = operator(x, y) + + op = relay.qnn.op.quantize(op, out_scale, relay.const(out_zp), out_dtype="int8") + x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + y_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np, y_np], allow_rounding_error=True) + + # Same axis + verify_binary_per_channel( + lhs_scale=np.random.uniform(1.0, 5.0, 3), + rhs_scale=np.random.uniform(1.0, 5.0, 3), + lhs_zp=0, + rhs_zp=0, + out_zp=0, + lhs_axis=1, + rhs_axis=1, + ) + verify_binary_per_channel( + lhs_scale=np.random.uniform(1.0, 5.0, 3), + rhs_scale=np.random.uniform(1.0, 5.0, 3), + lhs_zp=np.random.randint(1, 3), + rhs_zp=np.random.randint(1, 3), + out_zp=0, + lhs_axis=1, + rhs_axis=1, + ) + verify_binary_per_channel( + lhs_scale=np.random.uniform(1.0, 5.0, 3), + rhs_scale=np.random.uniform(1.0, 5.0, 3), + lhs_zp=np.random.randint(1, 3), + rhs_zp=np.random.randint(1, 
3), + out_zp=np.random.randint(1, 3), + lhs_axis=1, + rhs_axis=1, + ) + verify_binary_per_channel( + lhs_scale=np.random.uniform(1.0, 5.0, 224), + rhs_scale=np.random.uniform(1.0, 5.0, 224), + lhs_zp=np.random.randint(1, 3), + rhs_zp=np.random.randint(1, 3), + out_zp=np.random.randint(1, 3), + lhs_axis=-1, + rhs_axis=-1, + ) + + # Different axes + verify_binary_per_channel( + lhs_scale=np.random.uniform(1.0, 5.0, 224), + rhs_scale=np.random.uniform(1.0, 5.0, 224), + lhs_zp=0, + rhs_zp=0, + out_zp=0, + lhs_axis=2, + rhs_axis=3, + ) + verify_binary_per_channel( + lhs_scale=np.random.uniform(1.0, 5.0, 224), + rhs_scale=np.random.uniform(1.0, 5.0, 224), + lhs_zp=np.random.randint(1, 3), + rhs_zp=np.random.randint(1, 3), + out_zp=0, + lhs_axis=2, + rhs_axis=3, + ) + verify_binary_per_channel( + lhs_scale=np.random.uniform(1.0, 5.0, 224), + rhs_scale=np.random.uniform(1.0, 5.0, 224), + lhs_zp=np.random.randint(1, 3), + rhs_zp=np.random.randint(1, 3), + out_zp=np.random.randint(1, 3), + lhs_axis=2, + rhs_axis=3, + ) + + @pytest.mark.parametrize( "operator", [ relay.op.add, relay.op.multiply, relay.op.subtract, - relay.op.subtract, relay.op.minimum, relay.op.maximum, ],