@@ -895,6 +895,21 @@ def generate(
895895         effective_batch_size = batch_size
896896         effective_batch_mult = 1
897897
898+         if self.config.is_encoder_decoder:
899+             if decoder_start_token_id is None:
900+                 decoder_start_token_id = bos_token_id
901+
902+             assert (
903+                 decoder_start_token_id is not None
904+             ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
905+             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
906+             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
907+
908+             # get encoder and store encoder outputs
909+             encoder = self.get_encoder()
910+
911+             encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
912+
898913         # Expand input ids if num_beams > 1 or num_return_sequences > 1
899914         if num_return_sequences > 1 or num_beams > 1:
900915             input_ids_len = input_ids.shape[-1]
@@ -911,20 +926,6 @@ def generate(
911926             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
912927
913928         if self.config.is_encoder_decoder:
914-             if decoder_start_token_id is None:
915-                 decoder_start_token_id = bos_token_id
916-
917-             assert (
918-                 decoder_start_token_id is not None
919-             ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
920-             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
921-             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
922-
923-             # get encoder and store encoder outputs
924-             encoder = self.get_encoder()
925-
926-             encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
927-
928929             # create empty decoder_input_ids
929930             input_ids = torch.full(
930931                 (effective_batch_size * num_beams, 1),
@@ -933,6 +934,18 @@ def generate(
933934                 device=next(self.parameters()).device,
934935             )
935936             cur_len = 1
937+             batch_idx = self.encoder_outputs_batch_dim_idx
938+             assert (
939+                 batch_size == encoder_outputs[0].shape[batch_idx]
940+             ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]}"
941+             expanded_idx = (
942+                 torch.arange(batch_size)
943+                 .view(-1, 1)
944+                 .repeat(1, num_beams * effective_batch_mult)
945+                 .view(-1)
946+                 .to(input_ids.device)
947+             )
948+             encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
936949         else:
937950             encoder_outputs = None
938951             cur_len = input_ids.shape[-1]
0 commit comments