44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ from typing import Tuple
8+
79import torch
810from executorch .exir .scalar_type import ScalarType
911from torch .library import impl , Library
1012
13+ from .utils import get_conv1d_output_size
14+
1115lib = Library ("xtensa" , "DEF" )
1216
1317lib .define (
2529)
2630
# Quantized linear: activation zero point is a scalar int, while the weight
# zero point, requantization multiplier and shift come in as tensors
# (presumably to allow per-channel values — confirm against the C++ kernel).
lib.define(
    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
)
# Out-variant of quantized_linear (writes into a caller-provided buffer).
lib.define(
    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)

# Quantized convolution, functional and out variants. Padding is SymInt[] so
# it can stay symbolic under export; channel_last selects the memory layout.
lib.define(
    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
3344
3445m = Library ("xtensa" , "IMPL" , "Meta" )
@@ -58,17 +69,15 @@ def dequantize_per_tensor_meta(
5869 return input .new_empty (input .size (), dtype = torch .float )
5970
6071
@impl(m, "quantized_linear")
def quantized_linear_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    out_zero_point: int,
) -> torch.Tensor:
    """Meta (shape-inference) kernel for xtensa::quantized_linear.

    Performs no arithmetic; it only derives the output shape so the op can
    be traced/exported. Quantization parameters are accepted but unused.
    """
    # src comes in shape [leading_dims, in_dim]
    # weight comes in shape [out_dim, in_dim]
    # output comes in empty with shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    # NOTE(review): output dtype is hard-coded to uint8 — assumes this op
    # always produces uint8-quantized activations; confirm with the kernel.
    return src.new_empty(out_size, dtype=torch.uint8)
91+
92+
@impl(m, "quantized_conv")
def quantized_conv_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """Meta (shape-inference) kernel for xtensa::quantized_conv.

    Performs no arithmetic; it derives the output size from the input and
    weight shapes plus the conv hyper-parameters. Only the first element of
    stride/padding/dilation and the first kernel dimension are passed to the
    size helper, i.e. only the conv1d case is handled here.
    """
    # weight is laid out [out_channels, in_channels, *kernel_size]
    out_channels, _in_channels, *kernel_size = weight.shape
    in_size = input.shape
    # Input rank must be 3, 4 or 5 (the original comment said "at most 6",
    # but `< 6` actually excludes rank-6 inputs).
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size from the conv1d parameters.
    output_size = get_conv1d_output_size(
        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
    )

    return input.new_empty(output_size, dtype=input.dtype)