@@ -193,19 +193,33 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
193193
194194
195195 def forward (self , idx : Tensor , input_pos : Optional [Tensor ] = None ) -> Tensor :
196+ """Forward pass of the model.
197+
198+ Args:
199+ idx (`torch.LongTensor` of shape `(batch_size, seq_length)`):
200+ Indices of input sequence tokens in the vocabulary.
201+ input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
202+ Indices of positions of each input sequence token in the position embeddings.
203+ This argument is optional in training mode but required in
204+ inference mode (i.e. when `model.setup_caches(training=False)` is used).
205+
206+ Returns:
207+ Tensor: The output logits tensor.
208+ """
196209 assert self .freqs_cis is not None , "Caches must be initialized first"
197210
198211 if input_pos is None :
199212 mask = None
200213 freqs_cis = self .freqs_cis [:idx .shape [1 ]]
201- elif not self .linear_causal_mask :
202- mask = self .causal_mask [None , None , input_pos ]
203- elif len (input_pos )> 1 and self .linear_causal_mask : # prefill for linear causal mask
204- mask = torch .tril (torch .ones (len (input_pos ), self .max_seq_length , dtype = torch .bool , device = input_pos .device )).unsqueeze (0 ).unsqueeze (0 )
205- else : # decode_one_token for linear causal mask
206- self .causal_mask [0 ,0 ,0 ,input_pos ] = 1
207- mask = self .causal_mask
208- freqs_cis = self .freqs_cis [input_pos ]
214+ else :
215+ if not self .linear_causal_mask :
216+ mask = self .causal_mask [None , None , input_pos ]
217+ elif len (input_pos )> 1 and self .linear_causal_mask : # prefill for linear causal mask
218+ mask = torch .tril (torch .ones (len (input_pos ), self .max_seq_length , dtype = torch .bool , device = input_pos .device )).unsqueeze (0 ).unsqueeze (0 )
219+ else : # decode_one_token for linear causal mask
220+ self .causal_mask [0 ,0 ,0 ,input_pos ] = 1
221+ mask = self .causal_mask
222+ freqs_cis = self .freqs_cis [input_pos ]
209223
210224 x = self .tok_embeddings (idx )
211225
0 commit comments