19 changes: 19 additions & 0 deletions include/tvm/relay/qnn/attrs.h
@@ -106,6 +106,25 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
}
};

/*! \brief Attributes for broadcast operators */
struct BroadcastAttrs : public tvm::AttrsNode<BroadcastAttrs> {
int lhs_axis;
int rhs_axis;

TVM_DECLARE_ATTRS(BroadcastAttrs, "relay.attrs.BroadcastAttrs") {
TVM_ATTR_FIELD(lhs_axis)
.describe(
"The channel axis of the lhs tensor for channel-wise broadcast. "
"Default value is -1, which corresponds to the last axis.")
.set_default(-1);
TVM_ATTR_FIELD(rhs_axis)
.describe(
"The channel axis of the rhs tensor for channel-wise broadcast. "
"Default value is -1, which corresponds to the last axis.")
.set_default(-1);
}
};

} // namespace qnn
} // namespace relay
} // namespace tvm
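
For orientation, a minimal Python sketch of how these new attributes surface on a call node once the patch is built; all values are illustrative, and both axes fall back to the -1 default when omitted:

from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="int8")
y = relay.var("y", shape=(1, 4), dtype="int8")
call = relay.qnn.op.add(
    x, y,
    relay.const(0.25, "float32"), relay.const(0, "int32"),   # lhs scale / zero point
    relay.const(0.5, "float32"), relay.const(0, "int32"),    # rhs scale / zero point
    relay.const(0.125, "float32"), relay.const(0, "int32"),  # output scale / zero point
)
print(call.attrs.lhs_axis, call.attrs.rhs_axis)  # -1 -1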
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -494,6 +494,11 @@ class OneHotAttrs(Attrs):
"""Attributes used in one_hot operators"""


@tvm._ffi.register_object("relay.attrs.BroadcastAttrs")
class BroadcastAttrs(Attrs):
"""Attributes used in broadcast operators"""


@tvm._ffi.register_object("relay.attrs.QuantizeAttrs")
class QuantizeAttrs(Attrs):
"""Attributes used in quantize operators"""
63 changes: 60 additions & 3 deletions python/tvm/relay/qnn/op/qnn.py
@@ -593,7 +593,16 @@ def conv2d_transpose(


def add(
lhs,
rhs,
lhs_scale,
lhs_zero_point,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point,
lhs_axis=-1,
rhs_axis=-1,
):
"""Quantized addition with numpy-style broadcasting.

@@ -623,6 +632,14 @@ def add(
output_zero_point: relay.Expr
The zero point of output quantized expr.

lhs_axis: int
The channel axis for lhs quantization. Default value is -1, which corresponds
to the last axis.

rhs_axis: int
The channel axis for rhs quantization. Default value is -1, which corresponds
to the last axis.

Returns
-------
result : relay.Expr
@@ -638,6 +655,8 @@
rhs_zero_point,
output_scale,
output_zero_point,
lhs_axis,
rhs_axis,
)
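
A hedged usage sketch of the new arguments (shapes and quantization parameters below are illustrative, not from this PR): per-channel scales are passed as 1-D float32 constants, and lhs_axis/rhs_axis name the axis they run along, here the NHWC channel axis:

import numpy as np
from tvm import relay

lhs = relay.var("lhs", shape=(1, 8, 8, 4), dtype="int8")
rhs = relay.var("rhs", shape=(1, 8, 8, 4), dtype="int8")
zp = relay.const(0, "int32")
out = relay.qnn.op.add(
    lhs, rhs,
    relay.const(np.array([0.1, 0.2, 0.4, 0.8], "float32")), zp,  # per-channel lhs qparams
    relay.const(np.array([0.3, 0.3, 0.3, 0.3], "float32")), zp,  # per-channel rhs qparams
    relay.const(0.5, "float32"), zp,                             # per-tensor output qparams
    lhs_axis=3, rhs_axis=3,
)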


@@ -702,7 +721,16 @@ def dense(


def mul(
lhs,
rhs,
lhs_scale,
lhs_zero_point,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point,
lhs_axis=-1,
rhs_axis=-1,
):
"""Quantized multiplication with numpy-style broadcasting.

@@ -732,6 +760,14 @@ def mul(
output_zero_point: relay.Expr
The zero point of output quantized expr.

lhs_axis: int
The channel axis for lhs quantization. Default value is -1, which corresponds
to the last axis.

rhs_axis: int
The channel axis for rhs quantization. Default value is -1, which corresponds
to the last axis.

Returns
-------
result : relay.Expr
@@ -747,6 +783,8 @@
rhs_zero_point,
output_scale,
output_zero_point,
lhs_axis,
rhs_axis,
)
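
A companion sketch for qnn.mul, again with illustrative shapes and values. The canonicalization in src/relay/qnn/op/mul.cc below only handles per-channel scales when lhs_axis and rhs_axis refer to the same dimension; calls with scalar scales keep the existing per-tensor lowering regardless of the new defaults:

import numpy as np
from tvm import relay

a = relay.var("a", shape=(1, 2, 8, 8), dtype="int8")
b = relay.var("b", shape=(1, 2, 8, 8), dtype="int8")
ch_scales = relay.const(np.array([0.05, 0.1], "float32"))
zp = relay.const(0, "int32")
prod = relay.qnn.op.mul(
    a, b,
    ch_scales, zp,
    ch_scales, zp,
    relay.const(0.02, "float32"), zp,
    lhs_axis=1, rhs_axis=1,  # must name the same axis when scales are per-channel
)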


@@ -961,7 +999,16 @@ def sigmoid(x, scale, zero_point, output_scale, output_zero_point):


def subtract(
lhs,
rhs,
lhs_scale,
lhs_zero_point,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point,
lhs_axis=-1,
rhs_axis=-1,
):
"""Quantized subtraction with numpy-style broadcasting.

@@ -991,6 +1038,14 @@ def subtract(
output_zero_point: relay.Expr
The zero point of output quantized expr.

lhs_axis: int
The channel axis for lhs quantization. Default value is -1, which corresponds
to the last axis.

rhs_axis: int
The channel axis for rhs quantization. Default value is -1, which corresponds
to the last axis.

Returns
-------
result : relay.Expr
@@ -1006,6 +1061,8 @@
rhs_zero_point,
output_scale,
output_zero_point,
lhs_axis,
rhs_axis,
)


2 changes: 2 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -451,6 +451,8 @@ def binary(expr, type_map):
right_t.zero_point,
out_t.scale,
out_t.zero_point,
left_t.axis,
right_t.axis,
)

return [out, out_t]
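
A minimal end-to-end sketch of what this change enables in the pass (shapes and parameters are illustrative): when the dequantize/quantize ops around a binary op carry a channel axis, that axis is now forwarded into the rewritten qnn op via left_t.axis and right_t.axis:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4, 8, 8), dtype="int8")
y = relay.var("y", shape=(1, 4, 8, 8), dtype="int8")
scales = relay.const(np.array([0.1, 0.2, 0.4, 0.8], "float32"))
zp = relay.const(0, "int32")
a = relay.qnn.op.dequantize(x, scales, zp, axis=1)
b = relay.qnn.op.dequantize(y, scales, zp, axis=1)
out = relay.qnn.op.quantize(relay.add(a, b), relay.const(0.5, "float32"), zp, axis=1)

mod = tvm.IRModule.from_expr(relay.Function([x, y], out))
mod = relay.transform.InferType()(mod)
mod = relay.transform.FakeQuantizationToInteger()(mod)  # rewrites to qnn.add with axis 1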
10 changes: 8 additions & 2 deletions src/relay/qnn/op/add.cc
@@ -45,6 +45,12 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// Get the input dtype and shape.
QnnBinaryOpTensorType input_type(arg_types, 0);

const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
ICHECK(broadcast_attrs != nullptr);

auto lhs_axis = broadcast_attrs->lhs_axis;
auto rhs_axis = broadcast_attrs->rhs_axis;

// FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in
// the start, we can insert requantize at the end if both input tensors have same qnn params. In
// that case, we can first add the tensors, subtract the zero point, and requantize at the end.
@@ -68,11 +74,11 @@
// Requantize LHS if necessary. Computes Q_a'
auto requantized_lhs =
RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale,
args.output_zero_point, input_type.shape, lhs_axis);
// Requantize RHS if necessary. Computes Q_b'
auto requantized_rhs =
RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale,
args.output_zero_point, input_type.shape, rhs_axis);
// Computes Q_a' + Q_b'
auto output = Add(requantized_lhs, requantized_rhs);
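
To see why requantizing both inputs into the output grid before adding is sound, here is a small numeric sketch (per-tensor, illustrative values), assuming the canonical form Q_c = Q_a' + Q_b' - zp_c with the zero-point correction applied just below the lines shown here, using a float-reference requantize:

import numpy as np

def requant(q, s_in, zp_in, s_out, zp_out):
    # reference requantization: preserve the real value, change the grid
    return np.round((q - zp_in) * s_in / s_out) + zp_out

q_a, s_a, zp_a = 20, 0.25, 2   # real_a = (20 - 2) * 0.25 = 4.5
q_b, s_b, zp_b = 10, 0.5, 0    # real_b = 10 * 0.5 = 5.0
s_c, zp_c = 0.5, 1

q_c = requant(q_a, s_a, zp_a, s_c, zp_c) + requant(q_b, s_b, zp_b, s_c, zp_c) - zp_c
assert (q_c - zp_c) * s_c == 4.5 + 5.0  # exact for these values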

138 changes: 102 additions & 36 deletions src/relay/qnn/op/mul.cc
@@ -42,6 +42,8 @@ namespace qnn {
*/
Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
Expr output;

// Get the attrs.
QnnBinaryOpArguments args(new_args);

@@ -51,44 +53,108 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const auto int32_dtype = DataType::Int(32);
const auto float32_dtype = DataType::Float(32);

const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
ICHECK(broadcast_attrs != nullptr);

auto lhs_axis = broadcast_attrs->lhs_axis;
auto rhs_axis = broadcast_attrs->rhs_axis;

if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) {
/*
This is per-tensor quantized multiply.

A tensor multiplication c = a * b can be written in terms of respective
quantized tensors, scales and zero points as
S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).

We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
0. The quantized multiplication then becomes
Q_c = S'/S_c Q' + z_c,
which is essentially a requantization of tensor Q' into tensor Q_c.
*/

auto lhs_shifted = Cast(args.lhs, int32_dtype);
auto rhs_shifted = Cast(args.rhs, int32_dtype);

auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
}

if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
}

// Create a new tensor Q'
output = Multiply(lhs_shifted, rhs_shifted);

// Get the adjusted new scale and zero points.
float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
float new_scale_float = lhs_scale_float * rhs_scale_float;
auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
auto new_input_zero_point = zero_scalar;

// Requantize to get Q_c
output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
args.output_scale, args.output_zero_point, input_type.dtype);
} else if (lhs_axis == rhs_axis) {
/*
This is per-channel quantized multiply, assuming lhs_axis and rhs_axis refer to the
same dimension. Zero points are subtracted on the specified axis via broadcast,
lhs and rhs are then multiplied, and the output is requantized with the
per-channel product scales along that axis. TODO: support different axes.
*/

auto lhs_data = Cast(args.lhs, int32_dtype);
auto rhs_data = Cast(args.rhs, int32_dtype);

auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
// Broadcast lhs zero point if needed
int rank = static_cast<int>(input_type.shape.size());
int axis = (lhs_axis < 0) ? ((rank > 0) ? rank + lhs_axis : 0) : lhs_axis;
Expr lhs_zero_broadcast =
ExpandBiasToMatchAxis(Reshape(args.lhs_zero_point, {-1}), rank, {axis});
lhs_data = Subtract(lhs_data, Cast(lhs_zero_broadcast, DataType::Int(32)));
}

if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
// Broadcast rhs zero point if needed
int rank = static_cast<int>(input_type.shape.size());
int axis = (rhs_axis < 0) ? ((rank > 0) ? rank + rhs_axis : 0) : rhs_axis;
Expr rhs_zero_broadcast =
ExpandBiasToMatchAxis(Reshape(args.rhs_zero_point, {-1}), rank, {axis});
rhs_data = Subtract(rhs_data, Cast(rhs_zero_broadcast, DataType::Int(32)));
}

// Create a new tensor Q'
output = Multiply(lhs_data, rhs_data);

// Requantize to get Q_c
auto lhs_scales = GetFloatVectorFromConstant(args.lhs_scale);
auto rhs_scales = GetFloatVectorFromConstant(args.rhs_scale);
std::vector<double> output_multipliers;
for (size_t i = 0; i < lhs_scales.size(); i++) {
double multiplier = static_cast<double>(lhs_scales[i]) * static_cast<double>(rhs_scales[i]);
output_multipliers.push_back(multiplier);
}
auto new_input_scale = MakeConstantTensor(
float32_dtype, {static_cast<int64_t>(output_multipliers.size())}, output_multipliers);

output = Requantize(output, input_type.shape, new_input_scale, zero_scalar, args.output_scale,
args.output_zero_point, input_type.dtype, lhs_axis);

} else {
LOG(FATAL) << "Not supported: per-channel qnn.mul requires lhs_axis and rhs_axis to be the same.";
}

return output;
}
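
A numeric sketch of the per-channel branch above (illustrative values): Q' = (Q_a - zp_a) * (Q_b - zp_b), the new per-channel input scale is the elementwise product S'_c = S_a,c * S_b,c, and the result is requantized along the shared channel axis:

import numpy as np

q_a = np.array([[10, 20], [30, 40]], np.int32)  # last axis is the channel axis
q_b = np.array([[2, 4], [6, 8]], np.int32)
s_a = np.array([0.1, 0.2])                      # per-channel lhs scales
s_b = np.array([0.5, 0.25])                     # per-channel rhs scales
zp_a = zp_b = 0
s_c, zp_c = 0.05, 0

q_prime = (q_a - zp_a) * (q_b - zp_b)           # int32 tensor Q'
s_prime = s_a * s_b                             # new per-channel input scale
q_c = np.round(q_prime * s_prime / s_c) + zp_c  # per-channel requantize

real = ((q_a - zp_a) * s_a) * ((q_b - zp_b) * s_b)
assert np.allclose((q_c - zp_c) * s_c, real, atol=s_c / 2)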
