@@ -58,10 +58,13 @@ def from_plain(
     ):
         pass

+    @torch._dynamo.disable
     def __repr__(self):
-        int_data, scale, zero_point = self.get_plain()
-        layout_type = self.get_layout_type()
-        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
+        # This is a hack: torch.compile tries to trace the __repr__ function, which then
+        # calls `dequantize` and errors out. Removing the dequantize call avoids the error.
+        # int_data, scale, zero_point = self.get_plain()
+        # layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}"  # (int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"

     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
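The `@torch._dynamo.disable` decorator added in this hunk tells Dynamo to skip the decorated function entirely and run it eagerly, which is why it stops `torch.compile` from tracing into `__repr__`. A minimal standalone sketch of the effect (not from this patch):

```python
import torch

@torch._dynamo.disable
def debug_repr(t):
    # Runs eagerly; Dynamo inserts a graph break instead of tracing this body.
    return f"tensor of shape {tuple(t.shape)}"

@torch.compile
def f(x):
    print(debug_repr(x))  # would otherwise pull __repr__-style code into tracing
    return x + 1

f(torch.ones(2))
```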
@@ -152,10 +155,13 @@ def __init__(
         self.quant_max = quant_max
         self.zero_point_domain = zero_point_domain

+    @torch._dynamo.disable
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
-            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
161+ f"{ self .__class__ .__name__ } "
162+ # Same hack here
163+ #(data={self.dequantize()}, shape={self.shape}, "
164+ #f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )

     def dequantize(self, output_dtype=None):
@@ -552,6 +558,8 @@ class MarlinSparseAQTLayout(AQTLayout):
     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     __torch_function__ = classmethod(_dispatch__torch_function__)

+    @staticmethod
+    @torch._dynamo.disable
     def __new__(
         cls,
         int_data: torch.Tensor,
@@ -573,6 +581,7 @@ def __new__(
         shape = int_data.shape
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

+    @torch._dynamo.disable
     def __init__(
         self,
         int_data: torch.Tensor,
@@ -593,8 +602,24 @@ def __init__(
         self.group_size = group_size
         self.num_bits = num_bits

+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data = tensor_data_dict["int_data"]
+        scale = tensor_data_dict["scale"]
+        zero_point = tensor_data_dict["zero_point"]
+        meta = tensor_data_dict["meta"]
+        layout_type, original_shape, group_size, num_bits = tensor_attributes
+        return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits)
+
+    @torch._dynamo.disable
     def get_plain(self):
         from torchao.sparsity.marlin import unpack_from_marlin_24  # avoid circular import
+        unpack_from_marlin_24 = torch._dynamo.disable(unpack_from_marlin_24)
         int_data_expanded, scales_expanded = unpack_from_marlin_24(
             self.int_data,
             self.scale,
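The `__tensor_flatten__`/`__tensor_unflatten__` pair added above is the protocol `torch.compile` uses to trace wrapper tensor subclasses: the first returns the names of the inner tensor attributes plus static metadata, the second rebuilds the subclass from them. A minimal sketch of the contract (the `ScaledTensor` class and its fields are illustrative, not part of torchao; `__torch_dispatch__` is omitted for brevity):

```python
import torch

class ScaledTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, scale):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem, scale):
        self.elem = elem    # inner plain tensor
        self.scale = scale  # static, non-tensor metadata

    def __tensor_flatten__(self):
        # (names of inner-tensor attributes, arbitrary static metadata)
        return ["elem"], [self.scale]

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride):
        (scale,) = metadata
        return cls(inner_tensors["elem"], scale)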
@@ -606,6 +631,7 @@ def get_plain(self):
         return int_data_expanded, scales_expanded, self.zero_point

     @classmethod
+    @torch._dynamo.disable
     def from_plain(
         cls,
         int_data: torch.Tensor,
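Note that `torch._dynamo.disable` works both as a decorator and as a plain wrapper, which is how `get_plain` shields the locally imported `unpack_from_marlin_24` above; decorating `from_plain` here keeps the packing path out of traced regions as well. The call-site form in isolation (the kernel is a stand-in):

```python
import torch

def packing_kernel(x):
    # stand-in for unpack_from_marlin_24: something Dynamo cannot trace
    return torch.as_tensor(x.cpu().numpy().sum())

safe_kernel = torch._dynamo.disable(packing_kernel)  # same effect as the decorator
```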
@@ -674,7 +700,7 @@ def _apply_fn_to_data(self, fn):
 @MarlinSparseAQTLayout.implements(aten.detach.default)
 def block_sparse_detach(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
-
+

 @register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
@@ -920,7 +946,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias):
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
     # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
     y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
-        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
+        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16,
     ).t()
     y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
         *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
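For reference, the fusion in this hunk computes the following (a plain-float sketch with illustrative shapes; it assumes `alpha` applies `w_scales` per output channel, as the comment in the hunk says):

```python
import torch

W = torch.randint(-8, 8, (4, 8)).float()   # stand-in for int8 weights
x = torch.randint(-8, 8, (3, 8)).float()   # stand-in for int8 activations
w_scales = torch.rand(4)                   # per-output-channel weight scales
x_scales = torch.rand(3)                   # per-row activation scales

fused = (W @ x.t()) * w_scales.view(-1, 1)  # what the sparse mm does via alpha
y = fused.t() * x_scales.view(-1, 1)        # the remaining elementwise scale
ref = (x * x_scales.view(-1, 1)) @ (W * w_scales.view(-1, 1)).t()
assert torch.allclose(y, ref)               # matches full dequantized matmul
```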
@@ -1037,6 +1063,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):

 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
+        isinstance(weight_tensor, AffineQuantizedTensor) and
        _aqt_is_uint4(weight_tensor) and
        input_tensor.dtype == torch.float16 and
        len(weight_tensor.shape) == 2 and
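The new `isinstance` guard relies on `and` short-circuiting: dispatch checks are probed with arbitrary tensor types, so the AQT-specific predicates after the guard never see a plain tensor. In miniature, with hypothetical stand-in names:

```python
import torch

class FakeAQT(torch.Tensor):  # hypothetical stand-in for AffineQuantizedTensor
    pass

def _is_uint4(w):
    return w.dtype == torch.int32  # placeholder predicate; assumes AQT-only fields

def check(weight_tensor):
    # `and` short-circuits: the AQT-only predicate is never reached for plain tensors
    return isinstance(weight_tensor, FakeAQT) and _is_uint4(weight_tensor)

print(check(torch.ones(2)))  # False, without ever calling _is_uint4
```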
@@ -1046,11 +1073,13 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):

 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
     from torchao.sparsity.marlin import marlin_24_workspace, const
+    assert isinstance(weight_tensor, AffineQuantizedTensor)

     sparse_w_int4 = weight_tensor.layout_tensor.int_data
     scale = weight_tensor.layout_tensor.scale
     meta = weight_tensor.layout_tensor.meta
     original_shape = weight_tensor.layout_tensor.original_shape
+    print("original_shape", original_shape)
     num_bits = weight_tensor.layout_tensor.num_bits

     # Saves batch size for reshaping back to original shape after the matmul
@@ -1059,13 +1088,15 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
     batch_size = -1
     if input_tensor.dim() == 3:
         batch_size = input_tensor.size(0)
-        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1])

     size_m = input_tensor.shape[0]
     size_n = original_shape[1]
     size_k = input_tensor.shape[1]
     workspace_24 = marlin_24_workspace(original_shape[1])

+    print(size_m, size_n, size_k)
+
     # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
     if size_k % const.TILE != 0:
         pad_size = find_multiple(size_k, const.TILE)
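The padding step rounds `size_k` up to the next multiple of the Marlin tile size. A self-contained sketch, assuming `find_multiple(n, k)` rounds `n` up to a multiple of `k` as in torchao's utils:

```python
import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    # round n up to the next multiple of k (identity if already aligned)
    return n if n % k == 0 else n + k - (n % k)

TILE = 16
x = torch.randn(8, 100)
pad_size = find_multiple(x.shape[1], TILE)  # 112
x = F.pad(x, (0, pad_size - x.shape[1]))    # zero-pad last dim: (8, 112)
```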
@@ -1076,11 +1107,9 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
         input_tensor, sparse_w_int4, meta, scale,
         workspace_24, num_bits, size_m, size_n, size_k
     )
-    torch.cuda.synchronize()

-    # Reshape back to original shape
     if batch_size != -1:
-        out = out.reshape(batch_size, -1, out.shape[-1])
+        out = out.view(batch_size, -1, out.shape[-1])

     if bias is not None:
         out += bias.to(out.dtype)
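The batch handling in this hunk folds a 3-D input into 2-D for the kernel and views the result back afterwards; `view` suffices because the kernel output is freshly allocated and contiguous. In isolation, with illustrative shapes:

```python
import torch

x = torch.randn(4, 5, 64)                      # (batch, seq, k)
batch_size = x.size(0)
x2d = x.reshape(-1, x.shape[-1])               # (20, 64): fold batch into rows
out = x2d @ torch.randn(64, 32)                # kernel stand-in: (20, 32)
out = out.view(batch_size, -1, out.shape[-1])  # back to (4, 5, 32)
```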
@@ -1113,14 +1142,14 @@ def _(func, types, args, kwargs):
     # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
     # make the branches easier to understand in `_quantized_linear_op`
-    try:
-        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
-        if isinstance(input_tensor, AffineQuantizedTensor):
-            input_tensor = input_tensor.dequantize()
-        if isinstance(weight_tensor, AffineQuantizedTensor):
-            weight_tensor = weight_tensor.dequantize()
-        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+    # try:
+    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+    # except:
+    #     if isinstance(input_tensor, AffineQuantizedTensor):
+    #         input_tensor = input_tensor.dequantize()
+    #     if isinstance(weight_tensor, AffineQuantizedTensor):
+    #         weight_tensor = weight_tensor.dequantize()
+    #     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

 @implements(aten.addmm.default)
 def _(func, types, args, kwargs):
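The removed fallback swallowed every error with a bare `except:`, which also hides real kernel failures; commenting it out, as this hunk does, lets errors surface during debugging. A narrower variant, sketched under the assumption that `_quantized_linear_op` raises `NotImplementedError` when no dispatch path matches, would keep the general fallback while still exposing genuine bugs:

```python
import torch

def _quantized_linear_with_fallback(input_tensor, weight_tensor, bias):
    try:
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except NotImplementedError:
        # only fall back when no dispatch path matched, not on real kernel errors
        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
```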