1919# since the model classes inherit torch.nn.Module. 
2020import  math 
2121
22+ import  numba 
2223import  numpy  as  np 
2324import  torch 
2425from  torch .autograd  import  Function 
2526from  torch .nn  import  functional  as  F 
26- import  numba 
2727
2828from  neural_compressor .torch .utils  import  accelerator , logger 
2929
@@ -301,11 +301,11 @@ def unpack_tensor_with_torch(self, packed_tensor):
301301                unpacked_tensor [:, index ].copy_ (tmp .type (target_dtype ))
302302                accelerator .synchronize ()
303303        return  unpacked_tensor 
304-      
304+ 
305305    @staticmethod  
306306    @numba .jit (nopython = True , parallel = True ) 
307307    def  pack_array_with_numba_b4_c32 (
308-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
308+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
309309    ) ->  np .ndarray :
310310        for  i  in  range (new_in_features ):
311311            packed_array [:, i ] =  (
@@ -319,11 +319,11 @@ def pack_array_with_numba_b4_c32(
319319                |  (raw_array [:, i  *  n_pack ] &  0b1111 )
320320            )
321321        return  packed_array 
322-      
322+ 
323323    @staticmethod  
324324    @numba .jit (nopython = True , parallel = True ) 
325325    def  pack_array_with_numba_b4_c16 (
326-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
326+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
327327    ) ->  np .ndarray :
328328        for  i  in  range (new_in_features ):
329329            packed_array [:, i ] =  (
@@ -333,23 +333,20 @@ def pack_array_with_numba_b4_c16(
333333                |  (raw_array [:, i  *  n_pack ] &  0b1111 )
334334            )
335335        return  packed_array 
336-      
336+ 
@staticmethod
@numba.jit(nopython=True, parallel=True)
def pack_array_with_numba_b4_c8(
    raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
) -> np.ndarray:
    """Pack pairs of 4-bit values into the columns of ``packed_array``.

    For every output column ``col`` the low nibble is taken from input column
    ``col * n_pack`` and the high nibble from input column ``col * n_pack + 1``.

    Args:
        raw_array: 2-D array holding the values to pack; only the low 4 bits
            of each element are used.
        packed_array: Pre-allocated 2-D output array, written in place.
        n_pack: Stride, in input columns, between consecutive output columns.
        new_in_features: Number of output columns to fill.

    Returns:
        ``packed_array`` after it has been filled.

    Note:
        Input columns up to ``(new_in_features - 1) * n_pack + 1`` are read,
        so ``raw_array`` must provide at least that many columns.
    """
    for col in range(new_in_features):
        first = col * n_pack
        low_nibble = raw_array[:, first] & 0b1111
        high_nibble = (raw_array[:, first + 1] & 0b1111) << 4
        packed_array[:, col] = high_nibble | low_nibble
    return packed_array
348-      
345+ 
349346    @staticmethod  
350347    @numba .jit (nopython = True , parallel = True ) 
351348    def  pack_array_with_numba_b4_c64 (
352-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
349+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
353350    ) ->  np .ndarray :
354351        for  i  in  range (new_in_features ):
355352            packed_array [:, i ] =  (
@@ -372,11 +369,10 @@ def pack_array_with_numba_b4_c64(
372369            )
373370        return  packed_array 
374371
375-     
376372    @staticmethod  
377373    @numba .jit (nopython = True , parallel = True ) 
378374    def  pack_array_with_numba_b8_c32 (
379-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
375+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
380376    ) ->  np .ndarray :
381377        for  i  in  range (new_in_features ):
382378            packed_array [:, i ] =  (
@@ -386,11 +382,11 @@ def pack_array_with_numba_b8_c32(
386382                |  (raw_array [:, i  *  n_pack ] &  0b11111111 )
387383            )
388384        return  packed_array 
389-      
385+ 
390386    @staticmethod  
391387    @numba .jit (nopython = True , parallel = True ) 
392388    def  pack_array_with_numba_b8_c16 (
393-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
389+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
394390    ) ->  np .ndarray :
395391        for  i  in  range (new_in_features ):
396392            packed_array [:, i ] =  (
@@ -400,20 +396,20 @@ def pack_array_with_numba_b8_c16(
400396                |  (raw_array [:, i  *  n_pack ] &  0b11111111 )
401397            )
402398        return  packed_array 
403-      
399+ 
@staticmethod
@numba.jit(nopython=True, parallel=True)
def pack_array_with_numba_b8_c8(
    raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
) -> np.ndarray:
    """Copy the low 8 bits of every ``n_pack``-th input column into ``packed_array``.

    Args:
        raw_array: 2-D array holding the values to pack; only the low 8 bits
            of each element are kept.
        packed_array: Pre-allocated 2-D output array, written in place.
        n_pack: Stride, in input columns, between consecutive output columns.
        new_in_features: Number of output columns to fill.

    Returns:
        ``packed_array`` after it has been filled.
    """
    for col in range(new_in_features):
        source_col = col * n_pack
        packed_array[:, col] = raw_array[:, source_col] & 0b11111111
    return packed_array
412-      
408+ 
413409    @staticmethod  
414410    @numba .jit (nopython = True , parallel = True ) 
415411    def  pack_array_with_numba_b8_c64 (
416-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
412+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
417413    ) ->  np .ndarray :
418414        for  i  in  range (new_in_features ):
419415            packed_array [:, i ] =  (
@@ -427,11 +423,11 @@ def pack_array_with_numba_b8_c64(
427423                |  (raw_array [:, i  *  n_pack ] &  0b11111111 )
428424            )
429425        return  packed_array 
430-      
426+ 
431427    @staticmethod  
432428    @numba .jit (nopython = True , parallel = True ) 
433429    def  pack_array_with_numba_b2_c32 (
434-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
430+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
435431    ) ->  np .ndarray :
436432        for  i  in  range (new_in_features ):
437433            packed_array [:, i ] =  (
@@ -457,7 +453,7 @@ def pack_array_with_numba_b2_c32(
457453    @staticmethod  
458454    @numba .jit (nopython = True , parallel = True ) 
459455    def  pack_array_with_numba_b2_c16 (
460-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
456+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
461457    ) ->  np .ndarray :
462458        for  i  in  range (new_in_features ):
463459            packed_array [:, i ] =  (
@@ -471,11 +467,11 @@ def pack_array_with_numba_b2_c16(
471467                |  (raw_array [:, i  *  n_pack ] &  0b11 )
472468            )
473469        return  packed_array 
474-      
470+ 
475471    @staticmethod  
476472    @numba .jit (nopython = True , parallel = True ) 
477473    def  pack_array_with_numba_b2_c8 (
478-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
474+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
479475    ) ->  np .ndarray :
480476        for  i  in  range (new_in_features ):
481477            packed_array [:, i ] =  (
@@ -485,11 +481,11 @@ def pack_array_with_numba_b2_c8(
485481                |  (raw_array [:, i  *  n_pack ] &  0b11 )
486482            )
487483        return  packed_array 
488-      
484+ 
489485    @staticmethod  
490486    @numba .jit (nopython = True , parallel = True ) 
491487    def  pack_array_with_numba_b2_c64 (
492-         raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int 
488+         raw_array : np .ndarray , packed_array :  np .ndarray , n_pack : int , new_in_features :  int 
493489    ) ->  np .ndarray :
494490        for  i  in  range (new_in_features ):
495491            packed_array [:, i ] =  (
@@ -527,7 +523,7 @@ def pack_array_with_numba_b2_c64(
527523                |  (raw_array [:, i  *  n_pack ] &  0b11 )
528524            )
529525        return  packed_array 
530-      
526+ 
531527    def  pack_array_with_numba (
532528        self , raw_array : np .ndarray , n_pack : int , bits : int , compress_bits : int , compression_dtype = np .int32 
533529    ) ->  np .ndarray :
@@ -547,17 +543,18 @@ def pack_array_with_numba(
547543        new_in_features  =  (in_features  +  n_pack  -  1 ) //  n_pack 
548544        packed_array  =  np .zeros ((out_features , new_in_features ), dtype = compression_dtype )
549545        raw_array  =  raw_array .astype (compression_dtype )
550-          
546+ 
551547        pack_method_name  =  f"pack_array_with_numba_b{ bits }  _c{ compress_bits }  " 
552548        pack_method  =  getattr (self , pack_method_name )
553549        return  pack_method (raw_array , packed_array , n_pack , new_in_features )
554-          
550+ 
555551    @staticmethod  
556552    @numba .jit (nopython = True ) 
557553    def  pack_array_with_numba_yi (
558554        raw_tensor : np .ndarray , n_pack : int , bits : int , compression_dtype = np .int32 
559555    ) ->  np .ndarray :
560556        """Packs the input tensor by combining elements into a specified bit-width format using NumPy. 
557+ 
561558        Args: 
562559            raw_tensor (np.ndarray): The tensor to be packed. Shape: [out_features, in_features] or [1, in_features]. 
563560            n_pack (int): The number of elements to be packed together. 
@@ -575,7 +572,7 @@ def pack_array_with_numba_yi(
575572            for  i  in  range (new_in_features ):
576573                packed_tensor [:, i ] =  (
577574                    (raw_tensor [:, i  *  n_pack  +  7 ] <<  28 )
578-                     |  (raw_tensor [:, i  *  n_pack  +  6 ]   <<  24 )
575+                     |  (raw_tensor [:, i  *  n_pack  +  6 ] <<  24 )
579576                    |  (raw_tensor [:, i  *  n_pack  +  5 ] <<  20 )
580577                    |  (raw_tensor [:, i  *  n_pack  +  4 ] <<  16 )
581578                    |  (raw_tensor [:, i  *  n_pack  +  3 ] <<  12 )
@@ -585,25 +582,29 @@ def pack_array_with_numba_yi(
585582                )
586583
587584        return  packed_tensor 
588-      
585+ 
def pack_tensor_with_reshape(self, raw_tensor):
    """Pack groups of ``self.n_pack`` values per row into single integers.

    Element ``i`` of each group occupies bits ``[self.bits * i,
    self.bits * (i + 1))`` of the packed value, i.e. the first element of a
    group ends up in the least-significant bits.

    Args:
        raw_tensor: 2-D tensor of shape ``[rows, in_features]`` holding the
            values to pack. Only the low ``self.bits`` bits of each value
            are kept.

    Returns:
        Tensor of shape ``[rows, ceil(in_features / self.n_pack)]`` with
        dtype ``self.compression_dtype``, placed on ``raw_tensor``'s device.
    """
    raw_array = raw_tensor.cpu().numpy()
    n_rows = raw_array.shape[0]
    in_features = raw_array.shape[1]
    # Integer ceiling division: number of packed columns per row.
    target_len = (in_features + self.n_pack - 1) // self.n_pack
    target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype

    # Zero-pad the trailing partial group. Without this, the reshape below
    # either raises or silently mixes values from neighbouring rows whenever
    # in_features is not a multiple of n_pack.
    pad_width = target_len * self.n_pack - in_features
    if pad_width:
        raw_array = np.pad(raw_array, ((0, 0), (0, pad_width)))

    # Keep only the low `bits` bits of every value, as the numba packers do;
    # otherwise the sign bits of a negative value would spill into the fields
    # of its neighbours. The work is done in int64 so the mask always fits,
    # then narrowed to the storage dtype.
    mask = (1 << self.bits) - 1
    fields = raw_array.reshape(-1, self.n_pack).astype(np.int64) & mask

    packed_array = np.zeros(fields.shape[0], dtype=target_dtype)
    for i in range(self.n_pack):
        packed_array |= (fields[:, i] << (self.bits * i)).astype(target_dtype)

    packed_tensor = torch.from_numpy(packed_array.reshape((n_rows, target_len))).to(
        device=raw_tensor.device
    )
    return packed_tensor
600599
def pack_tensor_with_numpy(self, raw_tensor):
    """Pack ``raw_tensor`` on the host and return the result on its device.

    Bit widths of 2, 4 and 8 are handed to ``self.pack_array_with_numba``;
    every other width is delegated to ``self.pack_tensor_with_reshape``.

    Args:
        raw_tensor: Tensor holding the values to pack.

    Returns:
        The packed tensor, moved to ``raw_tensor``'s device.
    """
    if self.bits in (2, 4, 8):
        # NumPy equivalent of the torch storage dtype.
        np_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
        host_array = raw_tensor.cpu().numpy()
        packed = self.pack_array_with_numba(host_array, self.n_pack, self.bits, self.compress_bits, np_dtype)
        return torch.from_numpy(packed).to(device=raw_tensor.device)
    return self.pack_tensor_with_reshape(raw_tensor)
608609
609610    def  unpack_tensor_with_numpy (self , packed_tensor ):
0 commit comments