/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file src/relay/qnn/op/concatenate.cc
 * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis.
 */

#include <tvm/ir.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../op/tensor/transform.h"
#include "../../pass/pattern_util.h"
#include "../util.h"

namespace tvm {
namespace relay {
namespace qnn {

TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs);

Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
                        Array<tvm::Expr> input_zero_points, double output_scale,
                        int32_t output_zero_point, int axis) {
  auto attrs = make_node<QnnConcatenateAttrs>();
  attrs->input_scales = std::move(input_scales);
  attrs->input_zero_points = std::move(input_zero_points);
  attrs->output_scale = output_scale;
  attrs->output_zero_point = output_zero_point;
  attrs->axis = axis;
  static const Op& op = Op::Get("qnn.concatenate");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}
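
// Note: QNN follows the usual affine quantization convention, in which a quantized
// tensor Q with qnn params (scale s, zero point z) represents the real-valued tensor
// s * (Q - z). The scales and zero points above therefore fully describe the encoding
// of each input tensor as well as the desired output encoding.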

/*!
 * \brief Canonicalizes the QNN concatenate op.
 * \param attrs The QNN concatenate attrs.
 * \param new_args The new mutated args to the call node.
 * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for the concatenate op.
 */
Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                                const Array<tvm::relay::Type>& arg_types) {
  // Get the attrs.
  CHECK_EQ(new_args.size(), 1);
  const auto& data = new_args[0];
  const auto* concatenate_attrs = attrs.as<QnnConcatenateAttrs>();
  CHECK(concatenate_attrs != nullptr);
  auto input_scales = concatenate_attrs->input_scales;
  auto input_zero_points = concatenate_attrs->input_zero_points;
  auto output_scale = concatenate_attrs->output_scale;
  auto output_zero_point = concatenate_attrs->output_zero_point;

  // Get the input dtype and shape.
  CHECK_GE(arg_types.size(), 1);
  auto tuple_type = arg_types[0].as<TupleTypeNode>();
  CHECK(tuple_type != nullptr);

  // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting
  // requantize at the start, we could insert a single requantize at the end if and only
  // if all the input tensors share the same qnn params. This can be done in the future.

  // If the qnn params of an input do not match the output qnn params, we requantize that
  // input expr first, and then concatenate the requantized input exprs.

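  // For instance (illustrative numbers), a uint8 value q = 200 with input qnn params
  // (scale = 0.25, zero point = 0) represents the real value 0.25 * (200 - 0) = 50.0.
  // Requantizing to output qnn params (scale = 0.5, zero point = 127) yields
  // q' = 127 + round(50.0 / 0.5) = 227, which encodes the same real value.
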
  auto tuple_data = data.as<TupleNode>();
  CHECK(tuple_data != nullptr);

  int idx = 0;
  Array<Expr> requantized_exprs;
  for (auto quantized_expr : tuple_data->fields) {
    // Get the scale for the idx-th quantized input tensor.
    auto input_scale_expr = input_scales[idx].as<tvm::ir::FloatImm>();
    CHECK(input_scale_expr != nullptr);
    auto input_scale = input_scale_expr->value;

    // Get the zero point for the idx-th quantized input tensor.
    auto input_zero_point_expr = input_zero_points[idx].as<tvm::ir::IntImm>();
    CHECK(input_zero_point_expr != nullptr);
    auto input_zero_point = input_zero_point_expr->value;

    // Check if the input and output qnn params are the same. If not, requantize.
    if (input_scale != output_scale || input_zero_point != output_zero_point) {
      // Get the input shape and dtype.
      auto tensor_type = tuple_type->fields[idx].as<TensorTypeNode>();
      CHECK(tensor_type != nullptr);
      auto input_dtype = tensor_type->dtype;
      auto input_shape = tensor_type->shape;

      // Requantize the input.
      auto requantized_expr = Requantize(quantized_expr, input_shape, input_scale, input_zero_point,
                                         output_scale, output_zero_point, input_dtype);
      requantized_exprs.push_back(requantized_expr);
    } else {
      requantized_exprs.push_back(quantized_expr);
    }
    idx++;
  }
  return MakeConcatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis);
}
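
// Illustrative lowering (a sketch; tensor names and qnn params are made up): a call
//   qnn.concatenate((%a, %b), input_scales=[0.5, 0.25], input_zero_points=[127, 0],
//                   output_scale=0.5, output_zero_point=127, axis=0)
// canonicalizes to
//   concatenate((%a, requantize(%b, 0.25, 0, 0.5, 127)), axis=0)
// because %a already carries the output qnn params, so only %b is requantized.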

RELAY_REGISTER_OP("qnn.concatenate")
.describe(R"code(Concatenate the quantized input tensors along the given axis.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.QnnConcatenateAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to concatenate.")
.set_support_level(11)
.add_type_rel("QnnConcatenate", ConcatenateRel<QnnConcatenateAttrs>)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize);

TVM_REGISTER_API("relay.qnn.op._make.concatenate")
.set_body_typed(MakeQnnConcatenate);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm