 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision
+from build.utils import find_multiple, get_precision, use_et_backend
 
 
 #########################################################################
@@ -92,30 +92,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantizer.quantize(self.model_)
 
 
-#########################################################################
-###                    QuantHandler API definition                    ###
-###                  (unify with torchao in future)                   ###
-
-
-class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
-        self.model_ = model
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
-        pass
-
-    def convert_for_runtime(self) -> nn.Module:
-        pass
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
-
-
 #########################################################################
 ###          wrapper for setting precision as a QuantHandler          ###
 
@@ -521,7 +497,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
+    module, device, bitwidth: int, groupsize: Optional[int]
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -535,13 +511,13 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                     device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
+                    bitwidth=bitwidth,
                     groupsize=groupsize,
-                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, groupsize, packed
+                child, device, bitwidth, groupsize
             )
 
 
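
The recursion in this hunk is the usual module-tree rewrite; a minimal self-contained sketch of the pattern follows (the `make_replacement` factory is hypothetical, not part of this diff):

    import torch.nn as nn

    # Sketch of the rewrite pattern used above: recurse over named
    # children and swap each nn.Embedding in place via setattr.
    # `make_replacement` is a hypothetical factory, not from this diff.
    def replace_embeddings(module: nn.Module, make_replacement) -> None:
        for name, child in module.named_children():
            if isinstance(child, nn.Embedding):
                setattr(module, name, make_replacement(child))
            else:
                replace_embeddings(child, make_replacement)
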
@@ -554,19 +530,15 @@ def __init__(
         *,
         bitwidth: int = 8,
         groupsize: Optional[int] = None,
-        packed=True,
+        packed=True,  # we always pack bitwidth 4 now
     ):
-        # when quantization dictionary comes from JSON, packed is a string
-        if isinstance(packed, str):
-            packed = packed.lower() != "false"
         self.model_ = model
         self.device = device
         self.groupsize = groupsize
         self.bitwidth = bitwidth
-        self.packed = packed
 
     @torch.no_grad()
-    def create_quantized_state_dict(self, packed=False) -> Dict:
+    def create_quantized_state_dict(self) -> Dict:
         cur_state_dict = self.model_.state_dict()
 
         if self.bitwidth == 4:
@@ -596,7 +568,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
-                if packed:
+                if self.bitwidth == 4:
                     if weight.shape[-1] % 2 != 0:
                         raise RuntimeError("automatic padding not implemented yet")
 
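
The even-width check above guards the nibble packing that follows in the full file. A hedged sketch of that packing, consistent with the div-16 / remainder-16 unpacking shown in aoti_forward below, and assuming values were already offset into [0, 15] (names are illustrative, not from this diff):

    import torch

    # Illustrative sketch: pack pairs of 4-bit values in [0, 15] into one
    # uint8 per pair, high nibble first, so that div(16) / remainder(16)
    # in aoti_forward recovers them in order.
    def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
        if w.shape[-1] % 2 != 0:
            raise RuntimeError("automatic padding not implemented yet")
        pairs = w.view(*w.shape[:-1], w.shape[-1] // 2, 2).to(torch.uint8)
        return pairs[..., 0] * 16 + pairs[..., 1]
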
@@ -620,12 +592,12 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.model_, self.device, self.bitwidth, self.groupsize, self.packed
+            self.model_, self.device, self.bitwidth, self.groupsize
         )
         return self.model_
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
+        model_updated_state_dict = self.create_quantized_state_dict()
         self.convert_for_runtime()
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_
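
A hedged usage sketch of the three-step flow above (quantize the state dict, rewrite the modules, load the result); the enclosing handler class name is assumed to be EmbeddingOnlyInt8QuantHandler, since this diff shows only its methods:

    # Assumed class name; bitwidth alone now selects 4-bit packing.
    handler = EmbeddingOnlyInt8QuantHandler(model, "cpu", bitwidth=4, groupsize=32)
    quantized_model = handler.quantized_model()
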
@@ -637,30 +609,42 @@ def __init__(
         device,
         vocab_size: int,
         embedding_dim: int,
+        bitwidth: int,
         groupsize: Optional[int] = None,
+        *,
         dtype=torch.half,
-        packed=False,
     ) -> None:
         super().__init__()
         if groupsize is None or groupsize == 0:
             groupsize = embedding_dim
         self.groupsize = groupsize
         self.dtype = dtype
-        self.packed = packed
-        if not packed:
+        self.bitwidth = bitwidth
+
+        if use_et_backend():
+            self.forward = self.et_forward
+        else:
+            self.forward = self.aoti_forward
+
+        if bitwidth == 8:
             self.register_buffer(
                 "weight",
                 torch.empty(
                     (vocab_size, embedding_dim), dtype=torch.int8, device=device
                 ),
             )
-        else:  # packed
+        elif bitwidth == 4:  # packed
             self.register_buffer(
                 "weight",
                 torch.empty(
                     (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
                 ),
             )
+        else:
+            raise RuntimeError(
+                f"Quantized embedding does not support bitwidth={bitwidth}"
+            )
+
         groups_per_row = (embedding_dim + groupsize - 1) // groupsize
         if groups_per_row > 1:
             self.register_buffer(
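
The constructor-time dispatch added in this hunk binds self.forward once instead of branching on the backend at every call. A minimal standalone sketch of the pattern, using the use_et_backend helper this diff imports from build.utils:

    import torch.nn as nn
    from build.utils import use_et_backend

    # Sketch of constructor-time forward dispatch, as used above: choose
    # the backend-specific method once, at construction.
    class DispatchingModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.forward = self.et_forward if use_et_backend() else self.aoti_forward

        def et_forward(self, x):  # ExecuTorch path
            return x

        def aoti_forward(self, x):  # eager / AOT Inductor path
            return x
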
@@ -675,16 +659,22 @@ def __init__(
             )
 
     @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if False:  # Used for Executorch
-            return torch.ops.llama_quantized.embedding_byte.dtype(
+    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if self.bitwidth == 8:
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
 
+    @torch.no_grad()
+    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
 
-        if self.packed:
+        if self.bitwidth == 4:
             weight_even = self.weight.div(16, rounding_mode="trunc")
             weight_odd = self.weight.remainder(16)
             weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
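
The 4-bit eager path above is cut off by the end of the hunk; a hedged sketch of how the stacked nibbles flatten back to embedding_dim columns (the final .view is an assumption about the full file):

    import torch

    # Hedged sketch: invert pack_int4_pairs from earlier. The final .view
    # restoring (vocab_size, embedding_dim) is an assumption.
    def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
        high = packed.div(16, rounding_mode="trunc")
        low = packed.remainder(16)
        return torch.stack((high, low), dim=-1).view(packed.shape[0], -1)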