diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index 94fb3ac1db91..c89fc46113f7 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -104,7 +104,10 @@ def __init__(self, nx, n_ctx, config, scale=False):
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
         assert n_state % config.n_head == 0
-        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.register_buffer(
+            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4))
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
@@ -142,8 +145,8 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None):
         if self.scale:
             w = w / math.sqrt(v.size(-1))
         nd, ns = w.size(-2), w.size(-1)
-        b = self.bias[:, :, ns - nd : ns, :ns]
-        w = w * b - 1e4 * (1 - b)
+        mask = self.bias[:, :, ns - nd : ns, :ns]
+        w = torch.where(mask, w, self.masked_bias)
 
         if attention_mask is not None:
             # Apply the attention mask
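
For reference, a minimal standalone sketch (not part of the patch) showing that the `torch.where`-based masking reproduces the previous arithmetic masking. The toy shapes and the `.bool()` cast are illustrative additions; recent PyTorch versions require a boolean condition for `torch.where`, whereas the patched code passes the `uint8` buffer directly.

# Minimal sketch (not part of the patch): checks that the torch.where-based masking
# matches the previous arithmetic masking. Shapes are illustrative; the mask is cast
# to bool because recent PyTorch versions require a boolean condition for torch.where.
import torch

n_ctx = 8
nd = ns = n_ctx  # query and key lengths for a full (non-incremental) forward pass

# Causal mask shaped like the "bias" buffer: (1, 1, n_ctx, n_ctx)
bias = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
masked_bias = torch.tensor(-1e4)

w = torch.randn(1, 12, nd, ns)  # raw attention scores: (batch, heads, query, key)

# Old path: zero out masked scores and subtract a large constant there
b = bias[:, :, ns - nd : ns, :ns].float()
w_old = w * b - 1e4 * (1 - b)

# New path: keep the score where the mask is set, use masked_bias elsewhere
mask = bias[:, :, ns - nd : ns, :ns].bool()
w_new = torch.where(mask, w, masked_bias)

print(torch.allclose(w_old, w_new))  # True: same scores kept, -1e4 at masked positions

Registering `masked_bias` as a buffer also means the constant follows the module through `.to(device)` / `.half()` instead of being rebuilt as a Python float expression on every forward pass.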