
Commit 1a5aefc

[Seq2Seq Generation] Call encoder before expanding input_ids (#3370)

1 parent 39371ee commit 1a5aefc

File tree

3 files changed: 29 additions, 15 deletions

src/transformers/modeling_bart.py (1 addition, 1 deletion)

@@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel):
     config_class = BartConfig
     base_model_prefix = "model"
     pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
+    encoder_outputs_batch_dim_idx = 1  # outputs shaped (seq_len, bs, ...)

     def _init_weights(self, module):
         std = self.config.init_std

@@ -888,7 +889,6 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask,
             encoder_outputs, decoder_cached_states = past, None
         else:
             encoder_outputs, decoder_cached_states = past
-
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,

src/transformers/modeling_t5.py (1 addition, 0 deletions)

@@ -457,6 +457,7 @@ class T5PreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
+    encoder_outputs_batch_dim_idx = 0  # outputs shaped (bs, ...)

     @property
     def dummy_inputs(self):
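The new class attribute records where the batch dimension sits in each encoder's output tensor, as the inline comments in the two diffs describe. A minimal sketch with toy tensors (hypothetical sizes, not the transformers API) illustrating the two conventions:

import torch

# Toy shapes only, chosen for illustration:
seq_len, bs, hidden = 7, 2, 16

# BART's encoder at this commit returns time-first hidden states,
# (seq_len, batch, hidden), so its batch dimension is at index 1 ...
bart_style = torch.randn(seq_len, bs, hidden)
assert bart_style.shape[1] == bs  # encoder_outputs_batch_dim_idx = 1

# ... while T5's encoder returns batch-first hidden states,
# (batch, seq_len, hidden), so its batch dimension is at index 0.
t5_style = torch.randn(bs, seq_len, hidden)
assert t5_style.shape[0] == bs  # encoder_outputs_batch_dim_idx = 0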

src/transformers/modeling_utils.py (27 additions, 14 deletions)

@@ -895,6 +895,21 @@ def generate(
             effective_batch_size = batch_size
             effective_batch_mult = 1

+        if self.config.is_encoder_decoder:
+            if decoder_start_token_id is None:
+                decoder_start_token_id = bos_token_id
+
+            assert (
+                decoder_start_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
+            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
+
+            # get encoder and store encoder outputs
+            encoder = self.get_encoder()
+
+            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
+
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = input_ids.shape[-1]

@@ -911,20 +926,6 @@
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

         if self.config.is_encoder_decoder:
-            if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
-
-            assert (
-                decoder_start_token_id is not None
-            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
-            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
-            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
-
-            # get encoder and store encoder outputs
-            encoder = self.get_encoder()
-
-            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
-
             # create empty decoder_input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),

@@ -933,6 +934,18 @@
                 device=next(self.parameters()).device,
             )
             cur_len = 1
+            batch_idx = self.encoder_outputs_batch_dim_idx
+            assert (
+                batch_size == encoder_outputs[0].shape[batch_idx]
+            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
+            expanded_idx = (
+                torch.arange(batch_size)
+                .view(-1, 1)
+                .repeat(1, num_beams * effective_batch_mult)
+                .view(-1)
+                .to(input_ids.device)
+            )
+            encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
         else:
             encoder_outputs = None
             cur_len = input_ids.shape[-1]
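Taken together, moving this block means the encoder now runs once per input sequence on the un-expanded input_ids; previously input_ids was expanded first, so each of the num_beams (and num_return_sequences) copies was pushed through the encoder redundantly. Only the encoder outputs are now duplicated, via index_select along the batch dimension. A minimal sketch of that index expansion with toy sizes (batch_size=2, num_beams=3, effective_batch_mult=1; names and sizes are illustrative assumptions, not the transformers API):

import torch

batch_size, num_beams, effective_batch_mult = 2, 3, 1
batch_idx = 0  # T5-style: encoder output shaped (bs, seq_len, hidden)
encoder_hidden = torch.randn(batch_size, 5, 8)  # toy encoder output

expanded_idx = (
    torch.arange(batch_size)
    .view(-1, 1)                                  # [[0], [1]]
    .repeat(1, num_beams * effective_batch_mult)  # [[0, 0, 0], [1, 1, 1]]
    .view(-1)                                     # [0, 0, 0, 1, 1, 1]
)
# index_select copies each example's encoder output num_beams times
# along the batch dimension, matching the expanded decoder batch.
expanded = encoder_hidden.index_select(batch_idx, expanded_idx)
assert expanded.shape == (batch_size * num_beams, 5, 8)

Because each example's index is repeated contiguously, beam k of example i reads the same encoder states as every other beam of example i, which is exactly what the old per-beam encoder passes produced, at a fraction of the compute.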
