Commit 689f599

fix use cache (#3)
1 parent 7800457 commit 689f599

1 file changed, 5 insertions(+), 3 deletions(-)

src/transformers/models/opt/modeling_opt.py

@@ -317,7 +317,7 @@ def forward(
         # for the decoder
         is_cross_attention = key_value_states is not None

-        bsz, tgt_len, _ = hidden_states.size()
+        bsz, _, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states)
@@ -351,13 +351,15 @@ def forward(
         # if encoder bi-directional self-attention `past_key_value` is always `None`
         past_key_value = (key_states, value_states)

+        query_length = query_states.shape[1]
+        tgt_len = key_states.shape[-2]
+
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
-        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
         key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
         value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

-        _, query_length, _, _ = query_states.shape

         attn_dropout = self.dropout if self.training else 0.0
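Context for the fix (a hedged sketch, not part of the commit): with use_cache=True, hidden_states carries only the newly generated token(s), while key_states and value_states already include the concatenated past cache. A single tgt_len read from hidden_states therefore no longer matches the key/value reshape on the flash-attention path; the patch takes query_length from the queries and tgt_len from the keys instead. The snippet below reproduces the shape mismatch with made-up dimensions; bsz, num_heads, head_dim, and past_len are illustrative assumptions, and the variable names mirror the diff.

import torch

# Hypothetical sizes, for illustration only (not from the commit).
bsz, num_heads, head_dim = 2, 4, 8
past_len = 5  # tokens already stored in the KV cache

# With use_cache=True, only the newest token reaches forward():
hidden_states = torch.randn(bsz, 1, num_heads * head_dim)

# Keys after the cache concat, laid out like the attention module's
# _shape() output: (bsz, num_heads, past_len + 1, head_dim)
key_states = (
    torch.randn(bsz, past_len + 1, num_heads * head_dim)
    .view(bsz, past_len + 1, num_heads, head_dim)
    .transpose(1, 2)
)

# Before the fix: tgt_len was read from hidden_states, i.e. 1,
# but the cached keys hold past_len + 1 = 6 positions, so
# key_states.transpose(1, 2).view(bsz, 1, num_heads, head_dim)
# would fail during cached generation.
_, tgt_len_old, _ = hidden_states.size()

# After the fix: each length comes from the tensor it reshapes.
query_length = hidden_states.shape[1]  # 1; query_states.shape[1] in the diff
tgt_len = key_states.shape[-2]         # 6; includes the cached positions

key_for_flash = key_states.transpose(1, 2).view(bsz, tgt_len, num_heads, head_dim)
print(key_for_flash.shape)  # torch.Size([2, 6, 4, 8])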