def __init__(
    self,
    model: nn.Module,
    tokenizer: Union[SentencePieceTokenizer, Tiktoken],
    max_seq_length: Optional[int] = None,
    use_kv_cache: bool = False,
):
    """Wrap an eager-mode model for lm-eval style evaluation.

    Args:
        model: eager PyTorch module to evaluate.
        tokenizer: SentencePiece or Tiktoken tokenizer for en/decoding.
        max_seq_length: maximum sequence length; defaults to 2048 when None.
        use_kv_cache: when True, ``_model_call`` feeds tokens one position at
            a time together with an explicit position tensor (KV-cache path).
    """
    # Prefer CUDA when available; the same device string is handed to the
    # base wrapper and also kept as a torch.device for tensor placement.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    super().__init__(device=device)
    self._model = model
    self._tokenizer = tokenizer
    self._device = torch.device(device)
    self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
    self._use_kv_cache = use_kv_cache
5254 @property
5355 def eot_token_id (self ):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
8385 return decoded
8486
8587 def _model_call (self , inps ):
86- return self ._model (inps )
88+ if self ._use_kv_cache :
89+ result_logits = []
90+ for pos in range (self ._max_seq_length ):
91+ pos_tensor = torch .tensor ([pos ], dtype = torch .int64 )
92+ logits = self ._model (inps [:, pos : pos + 1 ], pos_tensor )
93+ result_logits .append (logits )
94+ return torch .cat (result_logits , dim = 1 )
95+ else :
96+ return self ._model (inps )
8797
8898 def _model_generate (self , context , max_length , eos_token_id ):
8999 raise Exception ("unimplemented" )
@@ -107,13 +117,22 @@ def __init__(
107117 from executorch .extension .pybindings .portable_lib import _load_for_executorch
108118
109119 self ._et_model = _load_for_executorch (self ._model )
120+ self ._use_kv_cache = self ._et_model .run_method ("use_kv_cache" )[0 ]
110121
111122 def _model_call (self , inps ):
112123 # Given inps (tokens), return the logits from a single forward call
113124 # inps: Tensor of shape (1, max_seq_len - 1)
114- # logits: Tensor of shape (1, max_seq_len - 1, 32000)
115- result = self ._et_model .forward ((inps ,))
116- return result [0 ]
125+ # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
126+ if self ._use_kv_cache :
127+ result_logits = []
128+ for pos in range (self ._max_seq_length ):
129+ pos_tensor = torch .tensor ([pos ], dtype = torch .int64 )
130+ logits = self ._et_model .forward ((inps [:, pos : pos + 1 ], pos_tensor ))
131+ result_logits .append (logits [0 ])
132+ return torch .cat (result_logits , dim = 1 )
133+ else :
134+ result = self ._et_model .forward ((inps ,))
135+ return result [0 ]
117136
118137
119138class ETRunnerEvalWrapper (GPTFastEvalWrapper ):
@@ -139,7 +158,7 @@ def _model_call(self, inps):
139158
140159 # Example:
141160 # inps: Tensor of shape (1, N)
142- # logits: Tensor of shape (1, N, 32000 )
161+ # logits: Tensor of shape (1, N, vocab_size )
143162 pass
144163
145164
@@ -212,6 +231,7 @@ def gen_eval_wrapper(
212231 # Exported model takes at most (max_seq_length - 1) tokens.
213232 # Note that the eager model takes at most max_seq_length tokens.
214233 max_seq_length = args .max_seq_length - 1 ,
234+ use_kv_cache = args .use_kv_cache ,
215235 )
216236
217237 # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
@@ -225,6 +245,7 @@ def gen_eval_wrapper(
225245 model = model ,
226246 tokenizer = tokenizer ,
227247 max_seq_length = args .max_seq_length ,
248+ use_kv_cache = args .use_kv_cache ,
228249 )
229250
230251
0 commit comments