@@ -525,14 +525,14 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
         return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
     return k_divisible_by_groupsize

-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize, dtype=torch.bfloat16):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
     c = torch.ops.aten._weight_int4pack_mm(
-        x.to(torch.bfloat16),
+        x.to(dtype),
         weight_int4pack,
         groupsize,
-        scales_and_zeros.to(torch.bfloat16)
+        scales_and_zeros.to(dtype)
     ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
@@ -546,12 +546,12 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+        bias=False, device=None, dtype=torch.bfloat16, groupsize: int = 128, inner_k_tiles: int = 8,
     ) -> None:
         super().__init__()
         self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
         if self.padding:
-            from model import find_multiple
+            from .utils import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)

@@ -567,9 +567,10 @@ def __init__(
             "weight",
             torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
         )
+        self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
+            torch.empty((in_features // groupsize, out_features, 2), dtype=self.dtype)
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -578,20 +579,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
-            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+            self.weight, self.scales_and_zeros, self.out_features, self.groupsize, self.dtype
         )

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None, dtype=torch.bfloat16):

     for name, child in module.named_children():
         if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
                     groupsize=groupsize, inner_k_tiles=inner_k_tiles,
+                    dtype=dtype,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func, dtype)

 class Int4WeightOnlyQuantizer(Quantizer):
     def __init__(
@@ -600,6 +602,7 @@ def __init__(
         padding_allowed: bool = True,
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cuda"),
+        precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -609,6 +612,8 @@ def __init__(
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
+        # precision and dtype are being used interchangeably here
+        self.precision: torch.dtype = precision

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -648,6 +653,7 @@ def _create_quantized_state_dict(
                     weight,
                     4,  # n_bit
                     self.groupsize,
+                    self.precision,  # dtype for scales_and_zeros
                 )
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
@@ -660,6 +666,8 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
             self.groupsize,
             self.inner_k_tiles,
             self.padding_allowed,
+            skip_layer_func=None,
+            dtype=self.precision,
         )
         return model

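
Taken together, the new precision argument flows from Int4WeightOnlyQuantizer through replace_linear_int4 and WeightOnlyInt4Linear down to linear_forward_int4, so the scales_and_zeros buffer and the int4 matmul inputs no longer assume bfloat16. Below is a minimal usage sketch, not part of this commit; it assumes this file is importable as torchao.quantization.GPTQ and that quantize(model) is the Quantizer entry point that chains _create_quantized_state_dict and _convert_for_runtime.

import torch
import torch.nn as nn
# Assumed import path; adjust to wherever this module lives in your checkout.
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

# Toy model: in_features should satisfy _check_linear_int4_k
# (divisible by groupsize and by inner_k_tiles * 16) unless padding_allowed=True.
model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to("cuda", torch.float16)

# With this change, scales_and_zeros and the _weight_int4pack_mm inputs follow
# `precision` instead of the previously hard-coded bfloat16.
quantizer = Int4WeightOnlyQuantizer(
    groupsize=128,
    inner_k_tiles=8,
    precision=torch.float16,  # the knob introduced by this diff
)
model = quantizer.quantize(model)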