Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,7 +1966,7 @@ def convert_fully_connected(self, op):
input_scale=input_tensor.qnn_params["scale"],
kernel_scale=weight_tensor.qnn_params["scale"],
units=weight_shape[0],
out_dtype="int32",
out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
)
else:
out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])
Expand All @@ -1977,7 +1977,7 @@ def convert_fully_connected(self, op):
if bias_tensor.tensor_idx != -1:
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
Expand Down Expand Up @@ -3175,7 +3175,7 @@ def convert_transpose_conv(self, op):
bias_tensor = input_tensors[3]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
Expand Down
10 changes: 6 additions & 4 deletions src/relay/qnn/op/convolution_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,14 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<Conv2DTransposeAttrs>();
ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr.";
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
<< "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype;
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
data->dtype == DataType::Int(64))
<< "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

// Check the types of scale and zero points.
Expand Down
10 changes: 6 additions & 4 deletions src/relay/qnn/op/dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<DenseAttrs>();
ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
<< "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
<< "Expected quantized dense type(int8, uint8, int16, uint16) for input but was "
<< data->dtype;
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
<< "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(32))
<< "Expected quantized dense type(int32) for output but was " << param->out_dtype;
ICHECK(param->out_dtype == DataType::Int(32) || param->out_dtype == DataType::Int(64))
<< "Expected quantized dense type(int32, int64) for output but was " << param->out_dtype;

// Check the types of scale and zero points.
for (size_t i = 2; i < 5; ++i) {
Expand Down
5 changes: 3 additions & 2 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
const auto in_dtype = data->dtype;
ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
<< "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
in_dtype == DataType::Int(16) || in_dtype == DataType::Int(32) ||
in_dtype == DataType::Int(64))
<< "Input type should be one of [int8, uint8, int16, int32, int64] but was " << in_dtype;

const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
int axis = requantize_attrs->axis;
Expand Down
4 changes: 2 additions & 2 deletions tests/python/contrib/test_ethosn/test_convert_equivalents.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def expected():
@requires_ethosn
@pytest.mark.parametrize(
"dtype,shape,constant_shape",
[("int16", (1, 16, 12, 4), None)],
[("float32", (1, 16, 12, 4), None)],
)
def test_unsupported_multiply_to_reinterpret_quantize(dtype, shape, constant_shape):
"""
Expand Down Expand Up @@ -445,7 +445,7 @@ def expected():
@pytest.mark.parametrize(
"dtype,shape,constant_shape",
[
("int16", (1, 16, 12, 4), None),
("float32", (1, 16, 12, 4), None),
],
)
def test_unsupported_add_to_reinterpret_quantize(dtype, shape, constant_shape):
Expand Down
23 changes: 23 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4878,6 +4878,28 @@ def representative_dataset():
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


def test_forward_ds_cnn_int16():
"""Test DS_CNN int16 quantized model"""
tflite_model_file = download_testdata(
"https://github.com/ARM-software/ML-zoo/blob/48f458af1e9065d9aad2ad94d24b58d6e7c00817/"
"models/keyword_spotting/ds_cnn_small/tflite_int16/ds_cnn_quantized.tflite?raw=true",
"ds_cnn_quantized_int16.tflite",
)

with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()

data = np.random.uniform(size=(1, 490)).astype("int16")

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input:0")
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


#######################################################################
# Unidirectional Sequence LSTM
# ---------------------
Expand Down Expand Up @@ -5250,3 +5272,4 @@ def test_forward_nms_v5():
test_forward_tflite_float16()

test_forward_tflite_int16()
test_forward_ds_cnn_int16()
Loading