@sshleifer (Contributor) commented Mar 21, 2020

Proposing to call model.encoder before expanding input_ids to effective_batch_size * num_beams.

For Bart, this saves 1.5 GB of GPU memory at batch_size=6. Savings are probably similar for T5 (untested).

This requires knowing which index of encoder_outputs corresponds to the batch dimension (the dimension we need to expand), which differs between Bart and T5. That difference is encoded in the self.encoder_outputs_batch_idx variable.

This PR is WIP because encoder_outputs_batch_idx could be avoided by transposing Bart's encoder_outputs, which I haven't tried.
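For illustration, the core idea can be sketched roughly as follows: run the encoder once at the original batch size, then repeat each example's encoder states num_beams times along the batch dimension. This is a minimal sketch, not the PR's actual code; expand_encoder_outputs is a hypothetical helper, and batch_dim plays the role of encoder_outputs_batch_idx.

```python
import torch

def expand_encoder_outputs(encoder_hidden, num_beams, batch_dim=0):
    """Repeat each example's encoder states num_beams times along the
    batch dimension (hypothetical helper illustrating the PR's idea)."""
    batch_size = encoder_hidden.shape[batch_dim]
    # index [0, 0, 0, 1, 1, 1] (for bs=2, num_beams=3) selects each
    # batch entry num_beams times
    expanded_idx = torch.arange(batch_size).repeat_interleave(num_beams)
    return encoder_hidden.index_select(batch_dim, expanded_idx)

# encode once at the original batch size, then expand for beam search
hidden = torch.randn(2, 7, 16)                       # (bs, seq_len, d_model)
expanded = expand_encoder_outputs(hidden, num_beams=3)
print(expanded.shape)                                # torch.Size([6, 7, 16])
```

The saving comes from never running the encoder on the already-expanded (bs * num_beams) input, which is what the previous code path effectively paid for.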

@sshleifer sshleifer changed the title [Generation/WIP] Call encoder earlier [Generation/WIP] Call encoder before expanding input_ids Mar 21, 2020
@sshleifer sshleifer changed the title [Generation/WIP] Call encoder before expanding input_ids [Seq2Seq Generation] Call encoder before expanding input_ids Mar 22, 2020
@sshleifer sshleifer marked this pull request as ready for review March 22, 2020 16:11
@sshleifer sshleifer requested review from julien-c and thomwolf March 22, 2020 16:27
@patrickvonplaten (Contributor) commented Mar 22, 2020

I like the change a lot!
One question I asked myself: with this change, encoder_outputs that are the same point to the same memory address. Could that lead to problems? Probably not, because the encoder_outputs are never changed, right?

I'd just propose some renaming.

config_class = BartConfig
base_model_prefix = "model"
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
encoder_outputs_batch_idx = 1 # outputs shaped (bs, ...)
I would change the name to encoder_outputs_batch_dim_idx.

pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_t5
base_model_prefix = "transformer"
encoder_outputs_batch_idx = 0 # outputs shaped (bs, ...)
I would change the name to encoder_outputs_batch_dim_idx here as well.

device=next(self.parameters()).device,
)
cur_len = 1
batch_idx = self.encoder_outputs_batch_idx
Also here: batch_dim_idx.

assert (
    batch_size == encoder_outputs[0].shape[batch_idx]
), f"expected encoder_outputs[0] to have batch dimension {batch_size}, got {encoder_outputs[0].shape[batch_idx]}"
expanded_index = (
Maybe also rename to expanded_idx, since we always use idx in this function?

@thomwolf (Member) left a comment

LGTM.
I agree with the name changes proposed by @patrickvonplaten.
