@@ -1890,16 +1890,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
18901890 qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
18911891
18921892 if re .match (qkv_pattern , name ):
1893- from einops import rearrange
1894-
18951893 bid = re .findall (qkv_pattern , name )[0 ]
18961894 qkv = data_torch
1897- qkv = rearrange (qkv .T , " o (g n i) ->o g n i" , g = num_groups , n = q_per_kv + 2 , i = head_dim )
1895+ # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
1896+ qkv = qkv .T .reshape ((- 1 , num_groups , q_per_kv + 2 , head_dim ))
18981897 q , k , v = qkv [..., : q_per_kv , :], qkv [..., q_per_kv : q_per_kv + 1 , :], qkv [..., q_per_kv + 1 : q_per_kv + 2 , :]
18991898 # The model weights of q and k require additional reshape .
1900- q = self ._hf_permute_qk (rearrange (q , " o g n i -> o (g n i)" ).T , num_heads , num_heads )
1901- k = self ._hf_permute_qk (rearrange (k , " o g n i -> o (g n i)" ).T , num_heads , num_kv_heads )
1902- v = rearrange (v , " o g n i -> o (g n i)" ).T
1899+ # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
1900+ q = self ._hf_permute_qk (q .reshape ((q .shape [0 ], - 1 )).T , num_heads , num_heads )
1901+ # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
1902+ k = self ._hf_permute_qk (k .reshape ((k .shape [0 ], - 1 )).T , num_heads , num_kv_heads )
1903+ # v = rearrange(v, " o g n i -> o (g n i)").T
1904+ v = v .reshape ((v .shape [0 ], - 1 )).T
19031905 return [
19041906 (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_Q , bid ), q ),
19051907 (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_K , bid ), k ),
@@ -2238,13 +2240,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
22382240class LazyTorchTensor :
22392241 _meta : Tensor
22402242 _data : Tensor | None
2241- _args : list [ Any ]
2242- _func : Callable [[list [ Any ]] , Tensor ] | None = None
2243+ _args : tuple
2244+ _func : Callable [[tuple ] , Tensor ] | None
22432245
2244- def __init__ (self , * , meta : Tensor , data : Tensor | None = None , args : list [ Any ] | None = None , func : Callable [[list [ Any ] ], Tensor ] | None = None ):
2246+ def __init__ (self , * , meta : Tensor , data : Tensor | None = None , args : tuple = () , func : Callable [[tuple ], Tensor ] | None = None ):
22452247 self ._meta = meta
22462248 self ._data = data
2247- self ._args = args if args is not None else []
2249+ self ._args = args
22482250 self ._func = func
22492251
22502252 @staticmethod
@@ -2266,19 +2268,22 @@ def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], Lazy
22662268 def wrapped_fn (* args , ** kwargs ):
22672269 if kwargs is None :
22682270 kwargs = {}
2269- args_list = ([ self ] if use_self else []) + list ( args )
2271+ args = (( self ,) if use_self else ()) + args
22702272
2271- meta_args = LazyTorchTensor ._recurse_apply (args_list , lambda t : t ._meta )
2273+ meta_args = LazyTorchTensor ._recurse_apply (args , lambda t : t ._meta )
22722274
2273- return LazyTorchTensor (meta = fn (* meta_args , ** kwargs ), args = args_list , func = lambda a : fn (* a , ** kwargs ))
2275+ return LazyTorchTensor (meta = fn (* meta_args , ** kwargs ), args = args , func = lambda a : fn (* a , ** kwargs ))
22742276 return wrapped_fn
22752277
22762278 def __getattr__ (self , __name : str ) -> Any :
22772279 meta_attr = getattr (self ._meta , __name )
2278- if not callable (meta_attr ):
2279- return meta_attr
2280- else :
2280+ if callable (meta_attr ):
22812281 return self ._wrap_fn (getattr (torch .Tensor , __name ), use_self = True )
2282+ elif isinstance (meta_attr , torch .Tensor ):
2283+ # for things like self.T
2284+ return self ._wrap_fn (lambda s : getattr (s , __name ))(self )
2285+ else :
2286+ return meta_attr
22822287
22832288 _dtype_map : dict [torch .dtype , type ] = {
22842289 torch .float16 : np .float16 ,
@@ -2295,7 +2300,7 @@ def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
22952300
22962301 @overload
22972302 @staticmethod
2298- def to_eager (t : list [ Tensor | LazyTorchTensor ] ) -> list [ Tensor ] : ...
2303+ def to_eager (t : tuple ) -> tuple : ...
22992304
23002305 @staticmethod
23012306 def to_eager (t : Any ) -> Any :
0 commit comments