Skip to content

Commit fa419d4

Browse files
committed
1 parent 6c92ce3 commit fa419d4

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/transformers/modeling_encoder_decoder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020

2121
from torch import nn
22-
22+
from .configuration_auto import AutoConfig
2323
from .modeling_auto import AutoModel, AutoModelWithLMHead
2424

2525

@@ -109,7 +109,7 @@ def from_pretrained(
109109
Examples::
110110
111111
# For example purposes. Not runnable.
112-
model = PreTrainedEncoderDecoder.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
112+
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
113113
"""
114114

115115
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
@@ -148,8 +148,10 @@ def from_pretrained(
148148

149149
decoder = kwargs_decoder.pop("model", None)
150150
if decoder is None:
151+
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
152+
decoder_config.is_decoder = True
153+
kwargs_decoder["config"] = decoder_config
151154
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
152-
decoder.config.is_decoder = True
153155

154156
model = cls(encoder, decoder)
155157

0 commit comments

Comments
 (0)