except:
    gemlite = None

-
aten = torch.ops.aten


@@ -35,7 +34,12 @@ def _same_metadata(
) -> bool:
    kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs)
    for k, v in self.gemlite_kwargs.items():
-        if k != "scale_activations":
+        if k in [
+            "in_features",
+            "out_features",
+            "packing_bitwidth",
+            "elements_per_sample",
+        ]:
            kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k])

    return (
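A hedged illustration of what the revised check accepts (made-up dicts that mirror the gemlite_kwargs built in from_plain further down, not values from this diff): only the shape- and packing-related keys must match, so impls that differ in runtime-only entries still compare equal.

# Illustrative only: these dicts are hypothetical.
a = {
    "in_features": 4096,
    "out_features": 4096,
    "packing_bitwidth": 32,
    "elements_per_sample": 8,
    "W_group_mode": 2,        # not compared by the loop above
    "data_contiguous": True,  # not compared by the loop above
}
b = dict(a, W_group_mode=4)   # differs only in an ignored key
# len(a) == len(b) and every compared key matches, so kwargs_match stays True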
@@ -80,6 +84,7 @@ def get_gemlite_aqt_kwargs(
    weight,
    group_size=64,
    bit_width=4,
+    packing_bitwidth=None,
    use_hqq=True,
):
    if gemlite is None:
@@ -99,6 +104,9 @@ def get_gemlite_aqt_kwargs(
    assert group_size is None or bit_width != 8, (
        "gemlite only works with group_size=None for bit_width=8"
    )
+    assert packing_bitwidth in [8, 16, 32, None], (
+        f"Invalid packing bitwidth, got {packing_bitwidth}"
+    )

    out_features, in_features = weight.shape
    group_size = in_features if group_size is None else group_size
@@ -107,15 +115,17 @@ def get_gemlite_aqt_kwargs(
    aqt_kwargs["_layout"] = GemlitePackedLayout(
        group_size=group_size,
        bit_width=bit_width,
+        packing_bitwidth=packing_bitwidth,
    )
    aqt_kwargs["use_hqq"] = use_hqq
    return aqt_kwargs


@dataclass(frozen=True)
class GemlitePackedLayout(Layout):
-    group_size: Optional[int] = 64
+    group_size: Optional[int] = 128
    bit_width: int = 4
+    packing_bitwidth: Optional[int] = None


@register_layout(GemlitePackedLayout)
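A hypothetical call site, assuming only what this diff shows about the signature (the surrounding quantization entry point is not shown here); the weight tensor is made up for illustration:

import torch

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # hypothetical 2D weight
aqt_kwargs = get_gemlite_aqt_kwargs(
    weight,
    group_size=128,       # matches the new GemlitePackedLayout default
    bit_width=4,
    packing_bitwidth=32,  # new knob; must be 8, 16, 32, or None (default behavior)
    use_hqq=True,
)
layout = aqt_kwargs["_layout"]  # GemlitePackedLayout(group_size=128, bit_width=4, packing_bitwidth=32)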
@@ -191,24 +201,36 @@ def from_plain(

        group_size, bit_width = _layout.group_size, _layout.bit_width
        out_features, in_features = int_data.shape
+        packing_bitwidth = _layout.packing_bitwidth

        if bit_width == 8 and group_size == in_features:
            gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
                int_data, scales=scale, bias=None
            )
        else:
-            gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
+            gemlite_linear = gemlite.helper.A16Wn(
+                device=int_data.device, packing_bitwidth=packing_bitwidth
+            ).from_weights(
                int_data, scale, zero_point, bit_width, group_size, bias=None
            )

+        meta_args = gemlite_linear.get_meta_args()
        gemlite_kwargs = {
            "in_features": in_features,
            "out_features": out_features,
-            "meta_args": gemlite_linear.get_meta_args(),
+            "packing_bitwidth": packing_bitwidth,
+            "data_contiguous": gemlite_linear.data_contiguous,
+            "elements_per_sample": gemlite_linear.elements_per_sample,
+            "W_group_mode": gemlite_linear.W_group_mode,
+            "meta_args": meta_args,
        }

        packed_weight, scale, zero_point = gemlite_linear.get_tensor_args()
        packed_weight = packed_weight.to(device)
+        if zero_point is None:
+            zero_point = torch.tensor(
+                [[]], device=packed_weight.device, dtype=torch.int32
+            )

        return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout)

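For context on the new elements_per_sample entry: under plain bit-packing it should be the number of quantized weights stored per packed word, roughly packing_bitwidth // bit_width. That relationship is an assumption about gemlite's packing scheme, not something stated in this diff; a minimal sketch under that assumption:

# Assumption: elements_per_sample = packing_bitwidth // bit_width for plain bit-packing.
packing_bitwidth = 32
bit_width = 4
elements_per_sample = packing_bitwidth // bit_width  # 8 int4 values per int32 word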
@@ -235,18 +257,39 @@ def _apply_fn_to_data(self, fn):
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        device = self.packed_weight.device
        int_data = (
-            gemlite.bitpack.unpack_over_rows(
-                self.packed_weight.cuda(),
-                W_nbits=self._layout.bit_width,
-                num_output_rows=self.gemlite_kwargs["out_features"],
-                dtype=torch.uint8,
+            (
+                gemlite.bitpack.unpack_over_rows(
+                    self.packed_weight.cuda(),
+                    W_nbits=self._layout.bit_width,
+                    num_output_rows=self.gemlite_kwargs["in_features"],
+                    dtype=torch.uint8,
+                )
            )
+            .to(device)
            .t()
-            .contiguous()
-        ).to(device)
+        )
+
+        # Preserve col-row major layout
+        if self.gemlite_kwargs["data_contiguous"]:
+            int_data = int_data.contiguous()
+
+        # Handle FMA mode: W_q * s + z -> (W_q - z) * s
+        if self.gemlite_kwargs["W_group_mode"] == 4:
+            scale_min_val = 1e-8
+            scale = self.scale.clone().float()
+            scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = (
+                scale_min_val
+            )
+            scale[
+                torch.logical_and(scale < 0, scale.abs() <= scale_min_val)
+            ] = -scale_min_val
+            zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100)
+            zero_point = zero_point.to(self.scale.dtype)
+        else:
+            zero_point = self.zero_point

        scale = self.scale.t().contiguous()
-        zero_point = self.zero_point.t().contiguous()
+        zero_point = zero_point.t().contiguous()

        return int_data, scale, zero_point

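A quick self-contained check of the FMA-mode conversion above (illustrative values, not from the diff): gemlite's W_group_mode == 4 stores parameters for W_q * s + z, while get_plain returns the (W_q - z') * s convention, so z' = -z / s.

import torch

W_q = torch.tensor([[3.0, 5.0]])
s = torch.tensor([[0.25, 0.25]])
z = torch.tensor([[1.0, -2.0]])

fma = W_q * s + z                      # gemlite's FMA convention
z_prime = (-z / s).clamp_(-100, 100)   # conversion applied in get_plain above
affine = (W_q - z_prime) * s           # torchao's (W_q - zero_point) * scale convention
assert torch.allclose(fma, affine)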
@@ -274,30 +317,47 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
            assert step == 1, "Only step == 1 is supported in slicing right now"

            if dim in [0, 1]:
-                int_data, scale, zero_point = self.get_plain()
-                data_len = int_data.shape[dim]
+                # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T
+                dim = 1 - dim
+                packed_weight = self.packed_weight
+                scale = self.scale
+                zero_point = self.zero_point
+
+                gemlite_kwargs = self.gemlite_kwargs.copy()
+                orig_shape = [
+                    gemlite_kwargs["in_features"],
+                    gemlite_kwargs["out_features"],
+                ]
+                elements_per_sample = gemlite_kwargs["elements_per_sample"]
+                data_len = orig_shape[dim]
                scale_len = scale.shape[dim]
                ratio = data_len / scale_len
                start_scale = int(start / ratio)
                end_scale = int(end / ratio)

-                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # For packing only the K dimension. This should be flipped for N-dim packing.
+                div = elements_per_sample if dim == 0 else 1
+                packed_weight = aten.slice.Tensor(
+                    packed_weight, dim, start // div, end // div, step
+                )
+
+                # Update in_features/out_features
+                gemlite_kwargs["in_features"] = (
+                    packed_weight.shape[0] * elements_per_sample
+                )
+                gemlite_kwargs["out_features"] = packed_weight.shape[1]
+
                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
                if zero_point is not None and zero_point.numel() > 0:
                    zero_point = aten.slice.Tensor(
                        zero_point, dim, start_scale, end_scale, step
                    )
                else:
                    zero_point = None
-                # this is to handle padding
-                int_data, scale, zero_point = self._layout.post_process(
-                    int_data, scale, zero_point, self.block_size
-                )
-
-                sliced = self.from_plain(
-                    int_data, scale, zero_point, self._layout
-                )  # Will be transposed again

+                sliced = GemliteAQTTensorImpl(
+                    packed_weight, scale, zero_point, gemlite_kwargs, self._layout
+                )
                return return_and_correct_aliasing(func, args, kwargs, sliced)

            else:
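A hedged sketch of the packed-K slicing arithmetic above (illustrative numbers, and assuming elements_per_sample = packing_bitwidth // bit_width): slice bounds expressed in unpacked-weight units are divided by elements_per_sample before indexing packed_weight, and in_features is recovered by multiplying back.

in_features, bit_width, packing_bitwidth = 4096, 4, 32
elements_per_sample = packing_bitwidth // bit_width        # assumed: 8 int4 values per int32 word
start, end = 0, 2048                                       # slice bounds in unpacked-weight units
start_p, end_p = start // elements_per_sample, end // elements_per_sample  # 0, 256 packed rows
new_in_features = (end_p - start_p) * elements_per_sample                  # back to 2048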
@@ -308,10 +368,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
        elif func is aten.copy_.default:
            self = args[0]
            src = args[1]
+
+            # Handle zero_point = None with symmetric quant
+            if self.zero_point is None:
+                self.zero_point = torch.tensor(
+                    [[]], device=self.packed_weight.device, dtype=torch.int32
+                )
+
+            if src.zero_point is None:
+                src.zero_point = torch.tensor(
+                    [[]], device=src.packed_weight.device, dtype=torch.int32
+                )
+
            if _same_metadata(self, src):
                self_tensors = self.__tensor_flatten__()[0]
                for tensor_name in self_tensors:
                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                for key in self.gemlite_kwargs:
+                    self.gemlite_kwargs[key] = src.gemlite_kwargs[key]
                return
            raise ValueError(
                f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"