@@ -27,7 +27,7 @@ class UintxTensor(torch.Tensor):
2727 int4_shard (torch.Tensor): 4 bit packed shard
2828 int2_shard (torch.Tensor): 2 bit packed shard
2929 int1_shard (torch.Tensor): 1 bit packed shard
30- bit_size (int): element size in bits
30+ bit_width (int): number of bits for each element
3131 pack_dim: (int) dimension to pack along
3232 """
3333 bits_to_shard = {
@@ -43,71 +43,71 @@ def __new__(
4343 cls ,
4444 shards : List [torch .Tensor ],
4545 packed_shape : List [int ],
46- bit_size : int ,
46+ bit_width : int ,
4747 pack_dim : int = - 1 ,
4848 ):
4949 kwargs = {"device" : shards [0 ].device }
5050 kwargs ["device" ] = shards [0 ].device
5151 kwargs ["layout" ] = shards [0 ].layout
5252 kwargs ["requires_grad" ] = False
5353 kwargs ["dtype" ] = torch .uint8
54- return torch .Tensor ._make_wrapper_subclass (cls , packed_shape , ** kwargs )
54+ return torch .Tensor ._make_wrapper_subclass (cls , packed_shape , ** kwargs )
5555
5656 def __init__ (
5757 self ,
5858 shards : List [torch .Tensor ],
5959 packed_shape : List [int ],
60- bit_size : int ,
60+ bit_width : int ,
6161 pack_dim : int = - 1 ,
6262 ):
63- for i , attrib in enumerate (self .bits_to_shard [bit_size ]):
63+ for i , attrib in enumerate (self .bits_to_shard [bit_width ]):
6464 setattr (self , attrib , shards [i ])
65-
65+
6666 self .packed_shape = packed_shape
67- self .bit_size = bit_size
67+ self .bit_width = bit_width
6868 self .pack_dim = pack_dim
69-
69+
7070 def get_shards (self ):
71- return [getattr (self ,i ) for i in self .__class__ .bits_to_shard [self .bit_size ]]
72-
71+ return [getattr (self ,i ) for i in self .__class__ .bits_to_shard [self .bit_width ]]
72+
7373 def __repr__ (self ):
74- return f"Int{ self .bit_size } Tensor(shape = { self .packed_shape } , data = { unpack (self .get_shards (), self .bit_size , dim = self .pack_dim )} )"
75-
74+ return f"Int{ self .bit_width } Tensor(shape = { self .packed_shape } , data = { unpack (self .get_shards (), self .bit_width , dim = self .pack_dim )} )"
75+
7676 def __tensor_flatten__ (self ):
77- return self .__class__ .bits_to_shard [self .bit_size ], [self .packed_shape , self .bit_size , self .pack_dim ]
78-
77+ return self .__class__ .bits_to_shard [self .bit_width ], [self .packed_shape , self .bit_width , self .pack_dim ]
78+
7979 @classmethod
8080 def __tensor_unflatten__ (
8181 cls , tensor_data_dict , tensor_attributes , outer_size , outer_stride
8282 ):
8383 shards = list (tensor_data_dict .values ())
84- packed_shape , bit_size , pack_dim = tensor_attributes
85- return cls (shards , packed_shape , bit_size , pack_dim )
84+ packed_shape , bit_width , pack_dim = tensor_attributes
85+ return cls (shards , packed_shape , bit_width , pack_dim )
    # Hook the subclass into torch's extension machinery via the shared
    # torchao helpers:
    #   implements            – per-op override registration (used below)
    #   __torch_dispatch__ / __torch_function__ – route torch calls through
    #                           the registered overrides.
    implements = classmethod(_implements)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
    __torch_function__ = classmethod(_dispatch__torch_function__)
9191 def get_plain (self ):
92- return unpack (self .get_shards (), self .bit_size , dim = self .pack_dim )
93-
92+ return unpack (self .get_shards (), self .bit_width , dim = self .pack_dim )
93+
9494 # temporary until kernels on packed tensors are created
9595 def apply_transformation (self , fn ):
9696 og = self .get_plain ()
9797 new = fn (og )
98- return self .from_uint8 (new , self .bit_size , self .pack_dim )
99-
98+ return self .from_uint8 (new , self .bit_width , self .pack_dim )
99+
100100 # temporary until kernels on packed tensors are created
101101 def apply_fn_to_shards (self , fn ):
102102 new_shards = [fn (shard ) for shard in self .get_shards ()]
103- return self .__class__ (new_shards , self .packed_shape , self .bit_size , self .pack_dim )
104-
103+ return self .__class__ (new_shards , self .packed_shape , self .bit_width , self .pack_dim )
104+
105105 @classmethod
106- def from_uint8 (cls , int_data : torch .Tensor , bit_size , pack_dim : int = - 1 ):
107- shards = pack (int_data , bit_size , dim = pack_dim )
106+ def from_uint8 (cls , int_data : torch .Tensor , bit_width , pack_dim : int = - 1 ):
107+ shards = pack (int_data , bit_width , dim = pack_dim )
108108 shape = list (int_data .shape )
109- shape [pack_dim ] = shape [pack_dim ] * bit_size // 8
110- return cls (shards , int_data .shape , bit_size , pack_dim )
109+ shape [pack_dim ] = shape [pack_dim ] * bit_width // 8
110+ return cls (shards , int_data .shape , bit_width , pack_dim )
# Decorator used below to register aten-op overrides on UintxTensor.
implements = UintxTensor.implements
@@ -118,19 +118,19 @@ def _(func, types, args, kwargs):
118118 return return_and_correct_aliasing (
119119 func , args , kwargs , args [0 ].apply_fn_to_shards (torch .detach )
120120 )
121-
121+
@implements(aten.view.default)
def _(func, types, args, kwargs):
    """aten.view override: unpack, view to the requested shape, repack."""
    target_shape = args[1:]
    reshaped = args[0].apply_transformation(lambda plain: plain.view(*target_shape))
    return return_and_correct_aliasing(func, args, kwargs, reshaped)
127-
127+
@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    """aten._to_copy override: returns the input tensor unchanged.

    NOTE(review): dtype/device kwargs are not applied here — confirm this
    no-op behavior is intended for packed tensors.
    """
    result = args[0]
    return return_and_correct_aliasing(func, args, kwargs, result)
133-
133+
134134@implements (aten .sub .Tensor )
135135def _ (func , types , args , kwargs ):
136136 return return_and_correct_aliasing (
@@ -147,18 +147,18 @@ def _(func, types, args, kwargs):
147147
@dataclass(frozen=True)
class UintxLayoutType(LayoutType):
    # bit_width: number of bits per quantized element
    # pack_dim: dimension the packed representation is stored along
    bit_width: int
    pack_dim: int = -1

    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        """Pack the quantized uint8 tensor into a UintxTensor."""
        return to_uintx(input, self.bit_width, self.pack_dim)
155155
156156@register_layout_cls (UintxLayoutType )
157157class UintxAQTLayout (PlainAQTLayout ):
158-
158+
159159 def get_plain (self ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
160160 return self .int_data .get_plain (), self .scale , self .zero_point
161-
161+
162162 @classmethod
163163 def from_plain (
164164 cls ,
@@ -169,39 +169,3 @@ def from_plain(
169169 ):
170170 assert isinstance (layout_type , UintxLayoutType )
171171 return cls (int_data , scale , zero_point , layout_type )
172-
173-
def uintx_affine_weight_only(bit_size, group_size=64, pack_dim=-1):
    """
    Applies uintx weight-only asymmetric per-group quantization to linear
    layers, using uintx quantization where x is the number of bits specified
    by the `bit_size` argument.

    Args:
        bit_size: number of bits per quantized weight element.
        group_size: number of weight elements sharing one scale/zero-point.
        pack_dim: dimension the packed representation is stored along.
    """
    from torchao.quantization.quant_primitives import (
        MappingType,
        ZeroPointDomain,
        choose_qparams_affine,
        quantize_affine,
        dequantize_affine,
    )
    from torchao.dtypes import to_affine_quantized
    from torchao.quantization.quant_api import _get_linear_subclass_inserter

    def apply_uintx_weight_only_quant(weight):
        # Fix: UintxLayoutType's field is named `bit_width`; the original
        # passed the stale keyword `bit_size=`, which no longer matches the
        # dataclass and would raise a TypeError.
        layout_type = UintxLayoutType(bit_width=bit_size, pack_dim=pack_dim)
        # Asymmetric per-group mapping into [0, 2**bit_size - 1] with
        # integer zero points.
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        quant_min = 0
        quant_max = 2**bit_size - 1
        eps = torch.finfo(torch.float32).eps
        zero_point_dtype = torch.int32
        zero_point_domain = ZeroPointDomain.INT

        return to_affine_quantized(
            weight, mapping_type, block_size, torch.uint8,
            quant_min=quant_min, quant_max=quant_max,
            eps=eps, zero_point_dtype=zero_point_dtype,
            zero_point_domain=zero_point_domain,
            layout_type=layout_type,
        )

    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant)
0 commit comments