@@ -128,12 +128,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids`' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


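A minimal standalone sketch of the new behavior, using an illustrative class name (the real module and class names are not shown in this hunk): positions are now derived from the token tensor itself and expanded to [bsz, seq_len], so call sites can pass input_ids directly instead of a torch.Size.

import torch
import torch.nn as nn


class LearnedPositionalEmbedding(nn.Embedding):
    # Illustrative stand-in for the embedding class patched above.
    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        # Derive positions from the tensor's shape and expand to [bsz, seq_len].
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)
        return super().forward(positions + self.offset)


input_ids = torch.randint(0, 100, (2, 5))                 # [bsz=2, seqlen=5]
embed = LearnedPositionalEmbedding(1024, 16)
print(embed(input_ids).shape)                             # torch.Size([2, 5, 16])
print(embed(input_ids, past_key_values_length=3).shape)   # positions start at 3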
@@ -788,17 +790,17 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
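The inputs_embeds[:, :, -1] slice above does not feed embedding values into the position lookup; it only hands embed_positions a tensor with the expected [bsz, seq_len] shape (and device) when no token ids are provided. A small sketch of that dispatch, with a hypothetical helper name:

import torch


def position_lookup_input(input_ids=None, inputs_embeds=None):
    # Hypothetical helper mirroring the branch above: whichever input is given
    # is reduced to a [bsz, seq_len] tensor that only carries shape and device.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        return input_ids                   # already [bsz, seq_len]
    elif inputs_embeds is not None:
        return inputs_embeds[:, :, -1]     # [bsz, seq_len, dim] -> [bsz, seq_len]
    raise ValueError("You have to specify either input_ids or inputs_embeds")


ids = torch.randint(0, 100, (2, 7))
emb = torch.randn(2, 7, 16)
assert position_lookup_input(input_ids=ids).shape == (2, 7)
assert position_lookup_input(inputs_embeds=emb).shape == (2, 7)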
@@ -1015,18 +1017,20 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale

         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1038,7 +1042,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
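In the decoder path, past_key_values_length shifts the computed positions during incremental decoding. A small, self-contained illustration of that offset (values chosen arbitrarily):

import torch

# With cached keys/values, only the newest token is fed, but its position index
# continues from the cached length instead of restarting at zero.
bsz, cached_len = 2, 6
step_ids = torch.randint(0, 100, (bsz, 1))     # one new token per sequence
seq_len = step_ids.shape[1]
positions = torch.arange(cached_len, cached_len + seq_len, dtype=torch.long).expand(bsz, -1)
print(positions)                               # tensor([[6], [6]])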