diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index ac3e1d77ecff..7a07495ba7e5 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -214,7 +214,7 @@ def forward(
         # This is analogous to the way that dropout layers scale down outputs during evaluation when not
         # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
         if self.token_dropout:
-            embeddings.masked_fill_((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
             mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
             src_lengths = attention_mask.sum(-1)
             mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
@@ -224,7 +224,7 @@ def forward(
 
         if self.position_embedding_type == "absolute":
             position_embeddings = self.position_embeddings(position_ids)
-            embeddings += position_embeddings
+            embeddings = embeddings + position_embeddings
 
         if self.layer_norm is not None:
             embeddings = self.layer_norm(embeddings)
@@ -399,7 +399,7 @@ def __init__(self, config):
     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states += input_tensor
+        hidden_states = hidden_states + input_tensor
         return hidden_states
 
 
@@ -474,7 +474,7 @@ def __init__(self, config):
     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states += input_tensor
+        hidden_states = hidden_states + input_tensor
         return hidden_states
 
 
@@ -633,7 +633,7 @@ def custom_forward(*inputs):
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[-1],)
+                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
                 if self.config.add_cross_attention:
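
For context, one common reason to replace in-place tensor ops (`masked_fill_`, `+=`) with their out-of-place equivalents is that autograd refuses to backpropagate through a tensor that was mutated in place after being saved for the backward pass. A minimal standalone sketch of that general PyTorch behaviour (an illustration only, not taken from the ESM code or asserted to be the sole motivation for this diff):

```python
import torch

a = torch.randn(3, requires_grad=True)
b = a.exp()          # exp() saves its output `b` for the backward pass
b += 1.0             # in-place update bumps b's version counter
b.sum().backward()   # RuntimeError: one of the variables needed for gradient
                     # computation has been modified by an inplace operation
```

Writing `b = b + 1.0` instead allocates a new tensor and leaves the saved one untouched, which is the pattern applied throughout the hunks above.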