1111
1212from typing import Any , Tuple
1313
14- import executorch .backends .arm .tosa_specification as tosa_specification
15-
14+ import serializer .tosa_serializer as ts # type: ignore
1615import torch .fx
1716import torch .fx .node
1817
@@ -247,25 +246,18 @@ def build_rescale_to_int32(
247246) -> Any :
248247 input_A_rescaled_to_int32 = None
249248
250- if isinstance (tosa_spec , tosa_specification .Tosa_1_00 ):
251- # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
252- # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
253- import serializer .tosa_serializer as ts # type: ignore
254-
255- input_A_rescaled_to_int32 = tosa_fb .addIntermediate (
256- input_arg .shape , ts .DType .INT32
257- )
249+ input_A_rescaled_to_int32 = tosa_fb .addIntermediate (input_arg .shape , ts .DType .INT32 )
258250
259- build_rescale (
260- tosa_fb ,
261- [rescale_scale ],
262- input_arg ,
263- input_A_rescaled_to_int32 .name ,
264- ts .DType .INT32 ,
265- [input_zp ],
266- [0 ],
267- rounding_mode = RoundingMode .SINGLE_ROUND ,
268- ) # type: ignore[call-arg]
251+ build_rescale (
252+ tosa_fb ,
253+ [rescale_scale ],
254+ input_arg ,
255+ input_A_rescaled_to_int32 .name ,
256+ ts .DType .INT32 ,
257+ [input_zp ],
258+ [0 ],
259+ rounding_mode = RoundingMode .SINGLE_ROUND ,
260+ ) # type: ignore[call-arg]
269261
270262 return input_A_rescaled_to_int32
271263
@@ -281,21 +273,19 @@ def build_rescale_from_int32(
281273 per_channel : bool = False ,
282274 tosa_spec = None ,
283275) -> None :
284- if isinstance (tosa_spec , tosa_specification .Tosa_1_00 ):
285- import serializer .tosa_serializer as ts # type: ignore
286-
287- # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
288- # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
289- build_rescale (
290- tosa_fb ,
291- [rescale_scale ],
292- input_node ,
293- output_name = output_name ,
294- output_type = ts .DType .INT8 ,
295- input_zp = [0 ],
296- output_zp = [output_zp ],
297- rounding_mode = RoundingMode .SINGLE_ROUND ,
298- ) # type: ignore[call-arg]
276+ # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
277+ # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
278+ build_rescale (
279+ tosa_fb ,
280+ [rescale_scale ],
281+ input_node ,
282+ output_name = output_name ,
283+ output_type = ts .DType .INT8 ,
284+ input_zp = [0 ],
285+ output_zp = [output_zp ],
286+ rounding_mode = RoundingMode .SINGLE_ROUND ,
287+ ) # type: ignore[call-arg]
288+
299289 return
300290
301291
@@ -318,18 +308,17 @@ def build_rescale_conv_output(
318308 (inp * w ) / out for inp , w , out in zip (input_scale , weight_scale , output_scale )
319309 ]
320310
321- if isinstance (tosa_spec [0 ], tosa_specification .Tosa_1_00 ):
322- # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
323- # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
324- build_rescale (
325- tosa_fb = tosa_fb ,
326- scale = post_conv2d_scale ,
327- input_node = op ,
328- output_name = output_name ,
329- output_type = output_type ,
330- input_zp = [0 ],
331- output_zp = output_zp ,
332- rounding_mode = RoundingMode .SINGLE_ROUND ,
333- per_channel = isinstance (weight_scale , torch .Tensor ),
334- ) # type: ignore[call-arg]
311+ # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
312+ # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
313+ build_rescale (
314+ tosa_fb = tosa_fb ,
315+ scale = post_conv2d_scale ,
316+ input_node = op ,
317+ output_name = output_name ,
318+ output_type = output_type ,
319+ input_zp = [0 ],
320+ output_zp = output_zp ,
321+ rounding_mode = RoundingMode .SINGLE_ROUND ,
322+ per_channel = isinstance (weight_scale , torch .Tensor ),
323+ ) # type: ignore[call-arg]
335324 return
0 commit comments