import tvm
from tvm import relay
from .. import op as reg

#################################################
# Register the functions for different operators.
#################################################

# Registering QNN Conv2D legalization function.
@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
    return qnn_conv2d_legalize(attrs, inputs, types)

# Registering QNN dense legalization function.
@reg.register_qnn_legalize("qnn.dense")
def legalize_qnn_dense(attrs, inputs, types):
    return qnn_dense_legalize(attrs, inputs, types)

# Default to None. If overridden by target, this will not be run.
# Generic QNN Conv2D legalization function.
@tvm.target.generic_func
def qnn_conv2d_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None

# Generic QNN Dense legalization function.
@tvm.target.generic_func
def qnn_dense_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None
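
# For illustration, a minimal sketch of how a new backend could hook into the
# generic functions above (the target key 'my_accel' is hypothetical). The
# Legalize pass dispatches on the current target's keys and falls back to the
# defaults above, which leave the QNN op untouched:
#
#   @qnn_conv2d_legalize.register('my_accel')
#   def _qnn_conv2d_legalize_my_accel(attrs, inputs, types):
#       # Returning None keeps qnn.conv2d unchanged on this backend.
#       return None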

###################
# Helper functions.
###################

# Helper function for lowering in the absence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that does
    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
    int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    # Collect the input exprs.
    data, kernel = inputs

    input_zp = attrs['input_zero_point']
    kernel_zp = attrs['kernel_zero_point']

    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
                                relay.const(input_zp, 'int16'))
    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
                                  relay.const(kernel_zp, 'int16'))
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    del new_attrs['kernel_zero_point']
    del new_attrs['input_zero_point']
    return relay_op(shift_data, shift_kernel, **new_attrs)
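
# For illustration (a sketch with assumed int8 inputs and zero points ZP_D and ZP_K):
# the helper above rewrites
#
#   qnn.conv2d(data, kernel, input_zero_point=ZP_D, kernel_zero_point=ZP_K, ...)
#
# into the plain Relay equivalent
#
#   nn.conv2d(cast(data, 'int16') - ZP_D, cast(kernel, 'int16') - ZP_K, ...)
#
# i.e. the zero points are folded into int16 subtractions, so targets without
# fast Int8 units operate on int16 operands instead.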

# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
    """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the
    dtypes are already good, we don't transform. Else, we shift the tensor values and zero points
    to change the dtype.

    Converting from int8 to uint8 can be done in the following manner.

    Original equation
      scale * (QA - zp_a)
      scale * (QA + 128 - 128 - zp_a)
      scale * ((QA + 128) - (zp_a + 128))

    Replacing QA + 128 with QA' and (zp_a + 128) with zp_a', we get the new quantized uint8
    tensor - scale * (QA' - zp_a'). The uint8 to int8 conversion follows similarly with a
    shift of -128.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    def _shift(data, zero_point, out_dtype):
        """Shifts (adds/subtracts) the qnn tensor by +/-128."""
        if out_dtype == 'uint8':
            shift = 128
        elif out_dtype == 'int8':
            shift = -128
        else:
            raise ValueError("Unsupported out dtype.")
        data_modified = relay.cast(data, 'int32')
        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
        data_modified = relay.cast(data_modified, out_dtype)
        return (data_modified, zero_point + shift)

    # Collect the dtypes.
    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype

    # Collect the input exprs.
    data, kernel = inputs

    # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
    if data_dtype == 'uint8' and kernel_dtype == 'int8':
        return None

    # Shift input if necessary.
    input_zp = attrs['input_zero_point']
    if data_dtype == 'int8':
        # Compute (QA + 128) and (zp_a + 128)
        data, input_zp = _shift(data, input_zp, 'uint8')

    # Shift kernel if necessary.
    kernel_zp = attrs['kernel_zero_point']
    if kernel_dtype == 'uint8':
        # Compute (QA - 128) and (zp_a - 128)
        kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')

    # Call the QNN op with modified inputs and zero points.
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    new_attrs['kernel_zero_point'] = kernel_zp
    return relay_op(data, kernel, **new_attrs)
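
# A worked example of the shift above (illustrative numbers): an int8 value QA = -3
# with zero point zp_a = 5 represents scale * (QA - zp_a) = scale * (-8). After
# legalization to uint8, QA' = -3 + 128 = 125 and zp_a' = 5 + 128 = 133, and
# scale * (QA' - zp_a') = scale * (125 - 133) = scale * (-8), so the represented
# value is unchanged while the dtypes become VNNI-friendly.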

# Helper function to change dtypes to be the same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
    """Sometimes MXNet + MKLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
    many devices like ARM prefer the datatypes to be the same for the HW units. This helper
    transforms conv2d/dense such that both the dtypes are the same.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    def _shift(data, zero_point, out_dtype):
        """Shifts (adds/subtracts) the qnn tensor by +/-128."""
        if out_dtype == 'uint8':
            shift = 128
        elif out_dtype == 'int8':
            shift = -128
        else:
            raise ValueError("Unsupported out dtype.")
        data_modified = relay.cast(data, 'int32')
        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
        data_modified = relay.cast(data_modified, out_dtype)
        return (data_modified, zero_point + shift)

    # Collect the dtypes.
    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype

    if data_dtype == kernel_dtype:
        return None

    # Collect the input exprs.
    data, kernel = inputs

    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
        "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"

    # Shift input if necessary.
    input_zp = attrs['input_zero_point']
    data, input_zp = _shift(data, input_zp, kernel_dtype)

    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    return relay_op(data, kernel, **new_attrs)
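
# For illustration (a sketch assuming uint8 data x int8 kernel, the MXNet + MKLDNN
# case mentioned above): the helper leaves the kernel alone and shifts only the
# data to the kernel's dtype,
#
#   data'     = cast(cast(data, 'int32') - 128, 'int8')
#   input_zp' = input_zp - 128
#
# so both operands become int8, which the ARM dot-product units prefer.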

def is_fast_int8_on_intel():
    """Checks whether the target supports fast Int8 arithmetic (Intel DLBoost), i.e. whether
    it is Skylake-AVX512 or above."""
    target = tvm.target.current_target(allow_none=False)
    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
    return bool(intel_supported_arches.intersection(set(target.options)))

def is_fast_int8_on_arm():
    """Checks whether the target supports fast Int8 arithmetic (ARM v8.2a dot-product
    instructions)."""
    target = tvm.target.current_target(allow_none=False)
    return '+v8.2a,+dotprod' in ' '.join(target.options)
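
# For example (the target strings below are assumptions based on common TVM usage;
# exact options depend on the device): with
#
#   tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod')
#
# as the current target, is_fast_int8_on_arm() returns True, while with
#
#   tvm.target.create('llvm -mcpu=skylake-avx512')
#
# is_fast_int8_on_intel() returns True.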

########################
# ARM CPU legalizations.
########################

@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
    # ARM prefers the dtypes to be the same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)

@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
    # ARM prefers the dtypes to be the same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)

##########################
# Intel CPU legalizations.
##########################

@qnn_conv2d_legalize.register('cpu')
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)

@qnn_dense_legalize.register('cpu')
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
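
# End-to-end usage (a sketch; 'mod' stands for a hypothetical Relay module that
# contains qnn.conv2d / qnn.dense ops, and relay.qnn.transform.Legalize is the
# pass that consults the registrations above). Running the pass inside a target
# scope picks the rule registered for that target's keys:
#
#   with tvm.target.create('llvm -mcpu=skylake-avx512'):
#       mod = relay.qnn.transform.Legalize()(mod)
#
# On this target the conv2d/dense inputs are legalized to uint8 x int8; on a
# generic 'llvm' target the same pass falls back to the int16 lowering helper.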