@@ -858,8 +858,12 @@ def update(
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
 
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
+        if cache_position is None:
+            k_out.copy_(key_states)
+            v_out.copy_(value_states)
+        else:
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
 
         return k_out, v_out
 
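The new branch gives `update()` two write paths: when `cache_position` is provided, the incoming key/value states are scattered into those positions; when it is `None`, the whole pre-allocated cache tensor is overwritten in place, presumably for states computed once for the entire sequence, such as cross-attention key/values derived from the encoder output. A minimal sketch of the two paths on plain tensors (shapes and names here are illustrative, not the library API):

```python
import torch

# Pre-allocated cache slot with shape (batch, heads, max_cache_len, head_dim).
k_out = torch.zeros(1, 4, 16, 8)

# Indexed write: scatter two new tokens into positions 5 and 6 (the usual decoding path).
key_states = torch.randn(1, 4, 2, 8)
cache_position = torch.tensor([5, 6])
k_out[:, :, cache_position] = key_states

# Full copy: with `cache_position=None`, the whole slot is overwritten in place,
# e.g. for encoder-derived (cross-attention) keys written once per generation.
full_key_states = torch.randn(1, 4, 16, 8)
k_out.copy_(full_key_states)
```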
@@ -971,6 +975,158 @@ def get_max_length(self) -> Optional[int]:
         # no matter how long the sentence is
         return None
 
+    def reset(self):
+        self.key_cache.zero_()
+        self.value_cache.zero_()
+
+
+class EncoderDecoderCache(Cache):
+    """
+    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
+    cross-attention caches.
+    """
+
+    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        self.self_attention_cache = self_attention_cache
+        self.cross_attention_cache = cross_attention_cache
+
+        self.is_updated = {}
+        for layer_idx in range(len(cross_attention_cache.key_cache)):
+            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (
+                self.self_attention_cache.key_cache[layer_idx],
+                self.self_attention_cache.value_cache[layer_idx],
+                self.cross_attention_cache.key_cache[layer_idx],
+                self.cross_attention_cache.value_cache[layer_idx],
+            )
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.self_attention_cache)
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        if len(self.cross_attention_cache) > 0:
+            for self_attn, cross_attn in zip(
+                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
+            ):
+                legacy_cache += (self_attn + cross_attn,)
+        else:
+            legacy_cache = self.self_attention_cache.to_legacy_cache()
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> "EncoderDecoderCache":
+        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx][:2]
+                cache.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(past_key_values[layer_idx]) > 2:
+                    key_states, value_states = past_key_values[layer_idx][2:]
+                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
+                    cache.is_updated[layer_idx] = True
+        return cache
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.self_attention_cache.key_cache) <= layer_idx:
+            return 0
+        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+    def reset(self):
+        if hasattr(self.self_attention_cache, "reset"):
+            self.self_attention_cache.reset()
+        if hasattr(self.cross_attention_cache, "reset"):
+            self.cross_attention_cache.reset()
+        elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
+            raise ValueError(
+                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
+                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
+                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
+                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
+            )
+        for layer_idx in self.is_updated:
+            self.is_updated[layer_idx] = False
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        self.self_attention_cache.reorder_cache(beam_idx)
+        self.cross_attention_cache.reorder_cache(beam_idx)
+
+    def check_dynamic_cache(self, method: str):
+        if not (
+            isinstance(self.self_attention_cache, DynamicCache)
+            and isinstance(self.cross_attention_cache, DynamicCache)
+        ):
+            raise ValueError(
+                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
+                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
+            )
+
+    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
+    def crop(self, maximum_length: int):
+        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+        self.check_dynamic_cache(self.crop.__name__)
+        self.self_attention_cache.crop(maximum_length)
+
+    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
+        """Split the current instance into a list of `EncoderDecoderCache` by the batch size. This will be used by
+        `_split_model_inputs()` in `generation.utils`"""
+        self.check_dynamic_cache(self.batch_split.__name__)
+        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
+        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+
+        out = []
+        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
+            out.append(EncoderDecoderCache(self_attn, cross_attn))
+        return out
+
+    @classmethod
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
+        `generation.utils`"""
+        self_attention_cache = DynamicCache()
+        cross_attention_cache = DynamicCache()
+        for idx in range(len(splits[0])):
+            layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
+            layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
+            self_attention_cache.update(layer_keys, layer_values, idx)
+
+            layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
+            layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
+            cross_attention_cache.update(layer_keys, layer_values, idx)
+        return cls(self_attention_cache, cross_attention_cache)
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
+        self.self_attention_cache.batch_repeat_interleave(repeats)
+        self.cross_attention_cache.batch_repeat_interleave(repeats)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_select_indices.__name__)
+        self.self_attention_cache.batch_select_indices(indices)
+        self.cross_attention_cache.batch_select_indices(indices)
+
 
 class HybridCache(Cache):
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
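Below is a hedged usage sketch of the new `EncoderDecoderCache`, round-tripping through the legacy tuple format with two `DynamicCache` instances. It assumes the classes are importable from `transformers.cache_utils` and that `DynamicCache` behaves as in the library; the layer count and tensor shapes are illustrative only.

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

num_layers, batch, heads, head_dim = 2, 1, 4, 8

# Legacy format: one 4-tuple per layer of
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value).
legacy = tuple(
    (
        torch.randn(batch, heads, 5, head_dim),   # self-attention keys (5 decoder tokens so far)
        torch.randn(batch, heads, 5, head_dim),   # self-attention values
        torch.randn(batch, heads, 12, head_dim),  # cross-attention keys (12 encoder positions)
        torch.randn(batch, heads, 12, head_dim),  # cross-attention values
    )
    for _ in range(num_layers)
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(len(cache))              # 2 layers
print(cache.get_seq_length())  # 5 cached decoder tokens (returned as a tensor)
print(cache.is_updated[0])     # True: cross-attention states already written for layer 0

# Backwards-compatible indexing returns the per-layer 4-tuple...
self_k, self_v, cross_k, cross_v = cache[0]

# ...and the conversion back mirrors the legacy structure.
roundtrip = cache.to_legacy_cache()
assert roundtrip[0][2].shape == legacy[0][2].shape

# Both wrapped caches are DynamicCache here, so the batch helpers are available:
# shard along the batch dimension and re-assemble the shards.
pieces = cache.batch_split(full_batch_size=batch, split_size=1)
restored = EncoderDecoderCache.from_batch_splits(pieces)
```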