@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -111,6 +112,7 @@ def _create_scale_param(
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
@@ -169,11 +171,15 @@ def create_weights(
                 **extra_weight_attrs)
 
     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Optional[Union[str,
+                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}
 
-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
@@ -205,15 +211,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # WEIGHT_SCALE / WEIGHT
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+
+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
             # WEIGHT
@@ -227,10 +235,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
@@ -317,11 +321,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.kv_scale
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
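
For reference, below is a minimal standalone sketch of the per-tensor round trip that the now-conditional requantization loop performs. It is illustrative only: the helper bodies are simplified stand-ins for the diff's per_tensor_quantize / per_tensor_dequantize, and the shard widths and scales are made-up toy values.

import torch

def per_tensor_quantize(tensor: torch.Tensor,
                        inv_scale: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in: scale into the fp8 range, clamp, cast down.
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)

def per_tensor_dequantize(tensor: torch.Tensor,
                          inv_scale: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in: cast up and re-apply the scale.
    return tensor.to(torch.float16) * inv_scale

# Toy fused layer: two logical shards packed into one weight, each shard
# quantized in the checkpoint with its own scale (made-up values).
logical_widths = [4, 4]
weight_scale = torch.tensor([0.01, 0.02])
weight = torch.empty(sum(logical_widths), 8, dtype=torch.float8_e4m3fn)
weight[:4, :] = per_tensor_quantize(torch.randn(4, 8), weight_scale[0])
weight[4:, :] = per_tensor_quantize(torch.randn(4, 8), weight_scale[1])

# The loop from process_weights_after_loading: dequantize each shard with its
# own scale, then requantize everything with the single max scale so the fused
# matmul can use one weight_scale.
max_w_scale = weight_scale.max()
start = 0
for idx, logical_width in enumerate(logical_widths):
    end = start + logical_width
    weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
    weight[start:end, :] = per_tensor_quantize(weight_dq, max_w_scale)
    start = end

When a checkpoint already stores the fused module under a single scale (the shard_id is None path in scales_shard_indexer), this pass is redundant, which is what the new fused_module_in_checkpoint flag skips.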