@@ -248,6 +248,10 @@ def __init__(
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -256,6 +260,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

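Note (commentary, not part of the diff): the new branch only runs when the quantization method has attached an `fp8_scales_shard_indexer` callable to the parameter. Below is a minimal sketch of what such a hook could look like for a per-shard scale tensor; the helper name `make_fp8_scales_param` and the one-slot-per-shard layout are illustrative assumptions, not necessarily the exact code behind this PR.

```python
# Illustrative sketch -- assumes one fp32 scale slot per logical shard.
from typing import Tuple

import torch
from torch.nn import Parameter


def make_fp8_scales_param(num_shards: int) -> Parameter:
    scales = Parameter(torch.empty(num_shards, dtype=torch.float32),
                       requires_grad=False)

    def fp8_scales_shard_indexer(
            param_data: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # ColumnParallelLinear / RowParallelLinear call this with shard_id=0;
        # merged layers pass the logical shard index instead.
        # Return a view into param_data so the caller's copy_() writes back
        # into the parameter.
        return param_data[shard_id:shard_id + 1], loaded_weight.reshape(1)

    # weight_loader discovers the hook with getattr(param, ...), so it is
    # attached directly to the Parameter object.
    scales.fp8_scales_shard_indexer = fp8_scales_shard_indexer
    return scales
```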
@@ -325,7 +335,12 @@ def weight_loader(self,

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -339,14 +354,13 @@ def weight_loader(self,
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
                 # for the packing.
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
-
-                    # If marlin, we need to adjust the offset and size to
-                    # account for the tiling.
+                    # Special case for Marlin.
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)

@@ -361,15 +375,14 @@ def weight_loader(self,
         if output_dim is not None:
             shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
             shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
             packed_dim = getattr(param, "packed_dim", None)
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
-
-                # If marlin, we need to adjust the offset and size to
-                # account for the tiling.
+                # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)

@@ -378,11 +391,17 @@ def weight_loader(self,
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
             shard_size = loaded_weight.shape[0]
             shard_offset = loaded_shard_id * shard_size
             param_data = param_data.narrow(0, shard_offset, shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
+
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
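Note (commentary, not part of the diff): the `loaded_shard_id` seen in this loader is an integer index into `self.output_sizes`, supplied by the model's checkpoint-loading code. A hedged sketch of that caller side, following the common stacked-params pattern; the mapping entries and function name are illustrative.

```python
# Illustrative caller-side sketch for a fused gate/up projection.
# Tuple layout: (fused param name, checkpoint name fragment, shard id).
stacked_params_mapping = [
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def load_one_tensor(params_dict, name, loaded_weight):
    for fused_name, ckpt_fragment, shard_id in stacked_params_mapping:
        if ckpt_fragment not in name:
            continue
        param = params_dict[name.replace(ckpt_fragment, fused_name)]
        # The integer shard id flows through weight_loader and, for FP8
        # scale parameters, on into fp8_scales_shard_indexer unchanged.
        param.weight_loader(param, loaded_weight, shard_id)
        return
    # Non-fused parameters use their own loader.
    # (Real model code usually falls back to a default loader via getattr.)
    param = params_dict[name]
    param.weight_loader(param, loaded_weight)
```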
@@ -477,7 +496,11 @@ def weight_loader(self,
                       loaded_shard_id: Optional[str] = None):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)

         if loaded_shard_id is None:
             # Loaded weight is already packed.
@@ -495,14 +518,14 @@ def weight_loader(self,
             ]
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantized Weights.
                 # If quantized, we need to adjust the offset and size to account
                 # for the packing.
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor

-                    # If marlin, we need to adjust the offset and size to
-                    # account for the tiling.
+                    # Special case for Marlin.
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)

@@ -524,15 +547,15 @@ def weight_loader(self,
                 shard_offset = (self.num_heads +
                                 self.num_kv_heads) * self.head_size
                 shard_size = self.num_kv_heads * self.head_size
+            # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
             packed_dim = getattr(param, "packed_dim", None)
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor

-                # If marlin, we need to adjust the offset and size to
-                # account for the tiling.
+                # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)

@@ -545,12 +568,17 @@ def weight_loader(self,
             start_idx = shard_id * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
             shard_size = loaded_weight.shape[0]
             shard_index = ["q", "k", "v"].index(loaded_shard_id)
             param_data = param_data.narrow(0, shard_index * shard_size,
                                            shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
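Note (commentary, not part of the diff): for QKVParallelLinear the shard id reaching the indexer is the string "q", "k", or "v" rather than an integer, so an FP8 scale hook has to translate it to a slot before indexing. A hedged extension of the earlier sketch; the mapping below is an assumption about how such an indexer could be written, not a claim about this PR's exact code.

```python
from typing import Tuple, Union

import torch

# Assumed slot layout for a fused QKV scale tensor of shape (3,).
_QKV_SLOT = {"q": 0, "k": 1, "v": 2}


def qkv_fp8_scales_shard_indexer(
        param_data: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # QKVParallelLinear.weight_loader forwards loaded_shard_id verbatim,
    # so both integer ids and "q"/"k"/"v" strings must be handled.
    if isinstance(shard_id, str):
        shard_id = _QKV_SLOT[shard_id]
    return param_data[shard_id:shard_id + 1], loaded_weight.reshape(1)
```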
@@ -642,6 +670,10 @@ def __init__(
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
@@ -650,6 +682,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

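Note (commentary, not part of the diff): an FP8 scale parameter typically has neither `input_dim` nor `output_dim` set, so in RowParallelLinear (and in the column-parallel case above) the tensor-parallel `narrow` branch is skipped and only the new indexer branch applies. A small self-contained sketch of that dispatch, with an inline stand-in indexer; names and values are illustrative.

```python
# Standalone sketch of the dispatch added in this PR for an unsharded
# per-tensor FP8 scale. The inline indexer is a stand-in, not vLLM code.
import torch
from torch.nn import Parameter

scales = Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False)
scales.fp8_scales_shard_indexer = (
    lambda param_data, loaded_weight, shard_id:
        (param_data, loaded_weight.reshape(1)))

loaded_scale = torch.tensor(0.0125)  # scalar scale read from a checkpoint

indexer = getattr(scales, "fp8_scales_shard_indexer", None)
input_dim = getattr(scales, "input_dim", None)

param_data = scales.data
if input_dim is not None:
    pass  # not taken here: the scale parameter is not sharded
elif indexer is not None:  # the branch this diff adds
    param_data, loaded_scale = indexer(param_data, loaded_scale, shard_id=0)

assert param_data.shape == loaded_scale.shape
param_data.copy_(loaded_scale)
print(scales)  # Parameter containing: tensor([0.0125])
```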