# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+from typing import Tuple
+
import torch
+from .utils import get_conv1d_output_size
from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

)

lib.define(
-    "quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
+    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+lib.define(
+    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
-    "quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)

m = Library("xtensa", "IMPL", "Meta")
@@ -58,17 +68,15 @@ def dequantize_per_tensor_meta(
    return input.new_empty(input.size(), dtype=torch.float)


-@impl(m, "quantized_linear_pt2")
-def quantized_linear_pt2_meta(
+@impl(m, "quantized_linear")
+def quantized_linear_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
-    in_scale: float,
    in_zero_point: int,
-    weight_scale: float,
-    weight_zero_point: int,
-    out_multiplier: int,
-    out_shift: int,
+    weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
    out_zero_point: int,
):
    # src comes in shape [leading_dims, in_dim]
@@ -79,3 +87,35 @@ def quantized_linear_pt2_meta(
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=torch.uint8)
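The Meta-key kernel above only propagates shapes, so it can be smoke-tested by calling the op on meta-device tensors. The snippet below is a minimal sketch, assuming this registration module has already been imported (so both the xtensa schema and the Meta implementation are registered); all sizes, dtypes, and zero-point values are made up for illustration.

import torch

# Hypothetical sizes: src is [2, 16], weight is [32, 16] (out_dim x in_dim).
src = torch.empty(2, 16, dtype=torch.uint8, device="meta")
weight = torch.empty(32, 16, dtype=torch.uint8, device="meta")
bias = torch.empty(32, dtype=torch.int32, device="meta")
weight_zero_point = torch.empty(1, dtype=torch.int32, device="meta")
out_multiplier = torch.empty(1, dtype=torch.int32, device="meta")
out_shift = torch.empty(1, dtype=torch.int32, device="meta")

# Argument order follows the quantized_linear schema defined above.
out = torch.ops.xtensa.quantized_linear(
    src, weight, bias, 0, weight_zero_point, out_multiplier, out_shift, 0
)
print(out.shape, out.dtype)  # torch.Size([2, 32]) torch.uint8, on the meta device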
+
+
+@impl(m, "quantized_conv")
+def quantized_conv_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    channel_last: bool = False,
+):
+    out_channels, _in_channels, *kernel_size = weight.shape
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = get_conv1d_output_size(
+        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
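The conv meta kernel can be sketched the same way. The sizes below are again invented (batch 1, 4 input channels, length 10, 8 output channels, kernel width 3), and the output length is whatever get_conv1d_output_size computes, so the shape in the comment is only the presumed standard conv1d result.

import torch

x = torch.empty(1, 4, 10, dtype=torch.uint8, device="meta")  # [N, C_in, L]
w = torch.empty(8, 4, 3, dtype=torch.uint8, device="meta")   # [C_out, C_in, kernel]
b = torch.empty(8, dtype=torch.int32, device="meta")
weight_zero_point = torch.empty(1, dtype=torch.int32, device="meta")
bias_scale = torch.empty(1, dtype=torch.float32, device="meta")
out_multiplier = torch.empty(1, dtype=torch.int32, device="meta")
out_shift = torch.empty(1, dtype=torch.int32, device="meta")

# Argument order follows the quantized_conv schema defined above;
# channel_last keeps its default of False.
y = torch.ops.xtensa.quantized_conv(
    x, w, b,
    [1], [0], [1], 1,                  # stride, padding, dilation, groups
    0, weight_zero_point, bias_scale,  # input_zero_point, weight_zero_point, bias_scale
    1.0, 0,                            # out_scale, out_zero_point
    out_multiplier, out_shift,
)
print(y.shape)  # presumably torch.Size([1, 8, 8]) for stride 1, no padding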