Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,11 @@ def _impl_int8(inputs, _):
alpha = inputs[1]
output_scale = _expr.const(inputs[3])
output_zero_point = _expr.const(inputs[4])
return relay.qnn.op.leaky_relu(inputs[0], alpha, output_scale, output_zero_point)
input_scale = _expr.const(inputs[5])
input_zero_point = _expr.const(inputs[6])
return relay.qnn.op.leaky_relu(
inputs[0], alpha, input_scale, input_zero_point, output_scale, output_zero_point
)

def _impl(inputs, _):
assert len(inputs) == 7, "Input quant params not found in op inputs"
Expand Down
21 changes: 13 additions & 8 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype="
reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)


def leaky_relu(x, alpha, scale, zero_point):
def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zero_point):
"""Quantized leaky relu.

Parameters
Expand All @@ -1188,11 +1188,14 @@ def leaky_relu(x, alpha, scale, zero_point):
The quantized input tensor.
alpha: double
The alpha value.
scale: relay.Expr
The scale of the quantized expr.
zero_point: relay.Expr
The zero point of quantized expr.

input_scale: relay.Expr
The scale of the input quantized expr.
input_zero_point: relay.Expr
The zero point of input quantized expr.
output_scale: relay.Expr
The scale of the output quantized expr.
output_zero_point: relay.Expr
The zero point of output quantized expr.
Returns
-------
result : relay.Expr
Expand All @@ -1201,6 +1204,8 @@ def leaky_relu(x, alpha, scale, zero_point):
return _make.leaky_relu(
x,
alpha,
scale,
zero_point,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
)
9 changes: 6 additions & 3 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,13 @@ def relu(expr, type_map):
def leaky_relu(expr, type_map):
"""Rewrite a leaky relu op"""
arg = expr.args[0]
t = type_map[arg]
x_t = type_map[arg]
out_t = type_map[expr]
alpha = expr.attrs.alpha
output = relay.qnn.op.leaky_relu(expr, alpha, t.scale, t.zero_point)
return [output, t]
output = relay.qnn.op.leaky_relu(
expr, alpha, x_t.scale, x_t.zero_point, out_t.scale, out_t.zero_point
)
return [output, x_t]


@register_fake_quantization_to_integer("nn.pad")
Expand Down
85 changes: 59 additions & 26 deletions src/relay/qnn/op/leaky_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ namespace qnn {

bool QnnLeakyReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: data, scale, zero_point
ICHECK_EQ(types.size(), 4);
// Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, out_type
ICHECK_EQ(types.size(), 6);
const auto* x = types[0].as<TensorTypeNode>();
if (x == nullptr) return false;
ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
Expand All @@ -42,31 +42,37 @@ bool QnnLeakyReluRel(const Array<Type>& types, int num_inputs, const Attrs& attr
ICHECK(param != nullptr) << "LeakyReluAttrs cannot be nullptr.";

// Check the types of scale and zero points.
for (size_t i = 1; i < 3; ++i) {
for (size_t i = 1; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}

ICHECK(IsScalarType(types[1], DataType::Float(32))); // scale
ICHECK(IsScalarType(types[2], DataType::Int(32))); // zero_point
ICHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point

// Assign types for scale and zero points.
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // scale
reporter->Assign(types[2], TensorType({}, DataType::Int(32))); // zero_point
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // input_scale
reporter->Assign(types[2], TensorType({}, DataType::Int(32))); // input_zero_point
reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // output_scale
reporter->Assign(types[4], TensorType({}, DataType::Int(32))); // output_zero_point

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// IdentityRel infer type function.
Array<Type> tensor_types = {types[0], types[3]};
Array<Type> tensor_types = {types[0], types[5]};
return IdentityRel(tensor_types, 2, attrs, reporter);
}

// Positional relay function to create quantized leaky relu operator used by frontend FFI.
Expr MakeQuantizedLeakyRelu(Expr x, double alpha, Expr scale, Expr zero_point) {
Expr MakeQuantizedLeakyRelu(Expr x, double alpha, Expr input_scale, Expr input_zero_point,
Expr output_scale, Expr output_zero_point) {
auto attrs = make_object<LeakyReluAttrs>();
attrs->alpha = alpha;
static const Op& op = Op::Get("qnn.leaky_relu");
return Call(op, {x, scale, zero_point}, Attrs(attrs), {});
return Call(op, {x, input_scale, input_zero_point, output_scale, output_zero_point}, Attrs(attrs),
{});
}

/*
Expand All @@ -82,42 +88,69 @@ Expr QnnLeakyReluCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// by a small alpha value < 1.
//
// We assume the same scale and zero point for alpha and the input tensor.
// Let T = s(q_t - z) where q_t is the input arg[0]
// Then, the quantized value of alpha * T is:
// q(a * T, s, z) = [(a * T) / s] + z = a * s(q_t - z) / s + z = a * (q_t - z) + z
// = a * q_t + (1 - a) * z
// LeakyReLU can be written in terms of respective quantized tensors, scales and
// zero points as
//
// We return the quantized value of alpha * T for all values q_t < input_zero_point.

ICHECK_EQ(new_args.size(), 3);
Expr quantized_data = Cast(new_args[0], DataType::Int(32));
// scale_o * (Q_o - zp_o) = alpha * scale_i * (Q_i - zp_i) when Q_i < zp_i (1)
// scale_o * (Q_o - zp_o) = scale_i * (Q_i - zp_i) when Q_i >= zp_i (2)
//
// Since the input qnn params can be different than output qnn params, we first requantize the
// input tensor to the output qnn params. After requantizing Q_i, equation (1) becames equation
// (3) where Q_i' is the requantized data from Q_i.
//
// scale_o * (Q_o - zp_o) = alpha * scale_o * (Q_i' - zp_o) when Q_i < zp_i (3)
// Q_o = alpha * Q_i' + (1 - alpha) * zp_o when Q_i < zp_i (4)
//
// It is equal to requantize Q_i to Q_o using scale_o and zp_o in equation (2).
// So equation (2) becomes
//
// Q_o = requantize(Q_i) when Q_i >= zp_i (5)
//
// Finnally, Q_o could be calculated by equation (4) and equation (5).
ICHECK_EQ(new_args.size(), 5);
Expr data = Cast(new_args[0], DataType::Int(32));
Expr input_scale = new_args[1];
Expr input_zero_point = Cast(new_args[2], DataType::Int(32));
Expr output_scale = new_args[3];
Expr output_zero_point = Cast(new_args[4], DataType::Int(32));

const auto* q_attrs = attrs.as<LeakyReluAttrs>();
auto alpha = q_attrs->alpha;

const auto input_shape = get_shape(arg_types[0]);
const auto input_dtype = arg_types[0].as<TensorTypeNode>()->dtype;

// requantize the input to Q_i'
auto requantized_expr = RequantizeOrUpcast(data, input_scale, input_zero_point, output_scale,
output_zero_point, input_shape);

// alpha * Q_i'
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(alpha);
auto prod = FixedPointMultiply(quantized_data, fixed_point_multiplier, shift);
auto prod = FixedPointMultiply(requantized_expr, fixed_point_multiplier, shift);

// (1 - alpha) * zp_o
int32_t fixed_point_multiplier_z, shift_z;
std::tie(fixed_point_multiplier_z, shift_z) = GetFixedPointMultiplierShift(1 - alpha);
auto scaled_z = FixedPointMultiply(input_zero_point, fixed_point_multiplier_z, shift_z);
auto scaled_z = FixedPointMultiply(output_zero_point, fixed_point_multiplier_z, shift_z);

// alpha * Q_i' + (1 - alpha) * zp_o
auto add = Add(prod, scaled_z);
auto output = Where(Less(quantized_data, input_zero_point), add, quantized_data);
auto output = Where(Less(data, input_zero_point), add, requantized_expr);

const auto* input_type = arg_types[0].as<TensorTypeNode>();
return ConvertDtype(output, input_type->dtype);
return ConvertDtype(output, input_dtype);
}

RELAY_REGISTER_OP("qnn.leaky_relu")
.describe("Leaky relu for quantized tensors.")
.set_attrs_type<LeakyReluAttrs>()
.set_num_inputs(3)
.set_num_inputs(5)
.add_argument("data", "Quantized Tensor", "The input data.")
.add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
.add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
.add_argument("output_zero_point", "Tensor",
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QLeakyRelu", QnnLeakyReluRel)
.set_attr<TNonComputational>("TNonComputational", true)
Expand Down
30 changes: 21 additions & 9 deletions tests/python/relay/test_op_qnn_leaky_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,36 @@ def dequantize(data, scale, zp):
return scale * (np.asarray(data) - zp)


def generate_golden_output(x_data, dequantized_x, alpha, scale, zero_point):
def generate_golden_output(x_data, dequantized_x, alpha, o_scale, o_zero_point, i_zero_point):
prod = np.multiply(dequantized_x, alpha)
prod = np.around(prod / scale + zero_point)
prod = np.around(prod / o_scale + o_zero_point)

output = np.where(x_data < zero_point, prod, x_data)
q_min = np.iinfo(np.uint8).min
q_max = np.iinfo(np.uint8).max
prod = np.clip(prod, q_min, q_max)

requantized = np.clip(np.round(dequantized_x / o_scale + o_zero_point), q_min, q_max)

output = np.where(x_data < i_zero_point, prod, requantized)
return output


def test_qnn_leaky_relu():
data_dtype = "uint8"
scale = 0.125
zero_point = 60
input_scale = 0.125
input_zero_point = 60
output_scale = 0.6
output_zero_point = 17
alpha = 0.9

x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.qnn.op.leaky_relu(
x=x,
alpha=alpha,
scale=relay.const(scale, "float32"),
zero_point=relay.const(zero_point, "int32"),
input_scale=relay.const(input_scale, "float32"),
input_zero_point=relay.const(input_zero_point, "int32"),
output_scale=relay.const(output_scale, "float32"),
output_zero_point=relay.const(output_zero_point, "int32"),
)

func = relay.Function([x], y)
Expand All @@ -53,8 +63,10 @@ def test_qnn_leaky_relu():
func = mod["main"]

x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
x_dequantized = dequantize(x_data, scale, zero_point)
golden_output = generate_golden_output(x_data, x_dequantized, alpha, scale, zero_point)
x_dequantized = dequantize(x_data, input_scale, input_zero_point)
golden_output = generate_golden_output(
x_data, x_dequantized, alpha, output_scale, output_zero_point, input_zero_point
)

op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)

Expand Down