@@ -42,6 +42,8 @@ namespace qnn {
4242
4343bool QnnConv2DRel (const Array<Type>& types, int num_inputs, const Attrs& attrs,
4444 const TypeReporter& reporter) {
45+ // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
46+ // out_type
4547 ICHECK_EQ (types.size (), 7 );
4648 const auto * data = types[0 ].as <TensorTypeNode>();
4749 const auto * weight = types[1 ].as <TensorTypeNode>();
@@ -57,22 +59,27 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
5759 ICHECK (param->out_dtype .bits () > 0 ) << " Output dtype bits should be greater than 0." ;
5860
5961 // Check the types of scale and zero points.
62+ for (size_t i = 2 ; i < 5 ; ++i) {
63+ if (types[i].as <IncompleteTypeNode>()) {
64+ return false ;
65+ }
66+ }
6067 ICHECK (IsScalarType (types[2 ], DataType::Int (32 ))); // input_zero_point
61- ICHECK (IsScalarType (types[3 ], DataType::Int (32 ))); // kernel_zero_point
68+ ICHECK (IsScalarType (types[3 ], DataType::Int (32 ))); // weight_zero_point
6269 ICHECK (IsScalarType (types[4 ], DataType::Float (32 ))); // input_scale
6370 // Kernel scale can be a vector of length output_channels or a scalar.
6471 if (param->groups == 1 ) {
6572 size_t axis = param->kernel_layout .operator std::string ().find (' O' );
6673 ICHECK (axis != std::string::npos) << " Kernel layout attribute is not defined" ;
67- AssignType (types[5 ], DataType::Float (32 ), weight->shape [axis], reporter); // kernel scale
74+ AssignType (types[5 ], DataType::Float (32 ), weight->shape [axis], reporter); // weight_scale
6875 } else {
6976 // Here, total number of output channels depend on depth multiplier.
7077 size_t o_axis = param->kernel_layout .operator std::string ().find (' O' );
7178 size_t i_axis = param->kernel_layout .operator std::string ().find (' I' );
7279 ICHECK (o_axis != std::string::npos || i_axis != std::string::npos)
7380 << " Kernel layout attribute is not defined" ;
7481 AssignType (types[5 ], DataType::Float (32 ), weight->shape [i_axis] * weight->shape [o_axis],
75- reporter); // kernel scale
82+ reporter); // weight_scale
7683 }
7784
7885 // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
0 commit comments