 
from executorch.exir.dialects._ops import ops as exir_ops
 
-from .qnn_constants import QNN_uint16
-
from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
 
 
    # Note that there is no int64 tensor data type in Qnn.
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
QNN_TENSOR_TYPE_MAP = {
    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
    float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}
@@ -170,7 +168,7 @@ def get_quant_encoding_conf(
        return self.make_qnn_per_tensor_config(quant_attrs)

    def get_quant_tensor_value(
-        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
+        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
    ) -> torch.Tensor:
        if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
            scale = quant_attrs["scale"]
@@ -179,16 +177,11 @@ def get_quant_tensor_value(
            scale = quant_attrs["scales"]
            zero_point = quant_attrs["zero_points"]

-        # To bypass torch.uint16 quantization is not supported
-        dtype = (
-            torch.int32
-            if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
-            else quant_attrs["dtype"]
-        )
+        dtype = quant_configs["dtype"]

        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
        # Make the backends access data correctly
-        if bitwidth == 4:
+        if quant_configs.get("bitwidth") == 4:
            mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
            tensor = torch.bitwise_and(mask, tensor)
        return tensor
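For reference, a minimal standalone sketch of the affine quantization this hunk performs, assuming a dict-shaped quant_configs as in the new signature. The function name and example values are hypothetical, not part of the change; recent PyTorch builds expose torch.uint16 with limited operator support.

```python
import torch

def quantize_tensor(tensor, quant_configs, scale, zero_point):
    # The cast target now comes from quant_configs rather than a separate dtype argument.
    quantized = tensor.div(scale).add(zero_point).round().to(quant_configs["dtype"])
    # For 4-bit data, keep only the low nibble so the backend reads the packed values correctly.
    if quant_configs.get("bitwidth") == 4:
        mask = torch.full(quantized.size(), 0x0F, dtype=torch.int8)
        quantized = torch.bitwise_and(mask, quantized)
    return quantized

# Example: scale 0.05, zero point 0, 16-bit unsigned output.
out = quantize_tensor(
    torch.tensor([0.1, 0.2, 0.3]),
    {"dtype": torch.uint16, "bitwidth": 16},
    scale=0.05,
    zero_point=0,
)
print(out)  # tensor([2, 4, 6], dtype=torch.uint16)
```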
@@ -237,7 +230,7 @@ def get_data_type(
                <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
            ):
                if unsigned:
-                    quant_config["dtype"] = QNN_uint16
+                    quant_config["dtype"] = torch.uint16
                else:
                    quant_config["dtype"] = torch.int16
            return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
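As an aside, a small sketch of the range-based selection above, limited to the 16-bit branch; the helper name and assertions below are hypothetical illustrations, not code from this change.

```python
import torch

def pick_16bit_dtype(quant_min: int, quant_max: int) -> torch.dtype:
    # An all-non-negative quantization range selects the unsigned type; otherwise use int16.
    unsigned = quant_min >= 0
    quant_range = quant_max - quant_min
    assert quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
    return torch.uint16 if unsigned else torch.int16

assert pick_16bit_dtype(0, 65535) is torch.uint16
assert pick_16bit_dtype(-32768, 32767) is torch.int16
```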
@@ -328,8 +321,7 @@ def define_tensor(
            tensor = self.get_quant_tensor_value(
                tensor,
                node.meta["quant_attrs"],
-                dtype,
-                quant_configs.get("bitwidth"),
+                quant_configs,
            )
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,