[Seq2Seq Generation] Call encoder before expanding input_ids #3370
Conversation
patrickvonplaten left a comment
Like the change a lot! I'd just propose some renaming.
src/transformers/modeling_bart.py
Outdated
    config_class = BartConfig
    base_model_prefix = "model"
    pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
    encoder_outputs_batch_idx = 1  # outputs shaped (bs, ...)
would change the name to encoder_outputs_batch_dim_idx
src/transformers/modeling_t5.py
Outdated
    pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    encoder_outputs_batch_idx = 0  # outputs shaped (bs, ...)
would change the name to encoder_outputs_batch_dim_idx
src/transformers/modeling_utils.py
Outdated
    device=next(self.parameters()).device,
    )
    cur_len = 1
    batch_idx = self.encoder_outputs_batch_idx
also here batch_dim_idx
src/transformers/modeling_utils.py
Outdated
    assert (
        batch_size == encoder_outputs[0].shape[batch_idx]
    ), f"expected encoder_outputs[0] to have batch dimension of size {batch_size}, got {encoder_outputs[0].shape[batch_idx]}"
    expanded_index = (
maybe also expanded_idx because we always use idx in the function?
thomwolf left a comment
LGTM.
I agree with the name changes proposed by @patrickvonplaten
Proposing to call model.encoder before expanding input_ids to effective_batch_size * num_beams. For Bart, this saves 1.5 GB of GPU memory at batch_size=6. Savings are probably similar for T5 (untested).
Requires knowing which index of the encoder_outputs is associated with the batch dim (we need to expand this dimension), which differs between Bart and T5. This difference is encoded in the self.encoder_outputs_batch_idx variable.
This PR is WIP because encoder_outputs_batch_idx could be avoided if we transposed Bart's encoder_outputs, which I haven't tried.
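The expansion described above can be sketched as follows. This is a minimal NumPy illustration of the idea, not the actual transformers code (which builds the index with torch and applies it via torch.index_select); expand_encoder_outputs is a hypothetical helper name.

```python
import numpy as np

def expand_encoder_outputs(encoder_hidden_states, num_beams, batch_dim_idx=0):
    """Repeat each example num_beams times along the batch dimension.

    Sketch of the PR's idea: run the encoder once per example, then expand
    its outputs for beam search, instead of expanding input_ids first and
    re-encoding every beam copy.
    """
    batch_size = encoder_hidden_states.shape[batch_dim_idx]
    # Index [0, 0, ..., 1, 1, ...]: example i appears num_beams times.
    expanded_idx = np.arange(batch_size).repeat(num_beams)
    return np.take(encoder_hidden_states, expanded_idx, axis=batch_dim_idx)

hidden = np.zeros((2, 7, 16))  # (bs, seq_len, d_model) => batch_dim_idx=0
print(expand_encoder_outputs(hidden, num_beams=3).shape)  # (6, 7, 16)
```

Passing batch_dim_idx (the PR's encoder_outputs_batch_idx, i.e. 0 for T5 and 1 for Bart here) is what lets the same expansion code handle encoders whose outputs put the batch at different dimensions.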