@@ -1785,6 +1785,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -2314,6 +2319,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
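For context, the pattern the diff introduces, applying optional query/key normalization to the per-head projections right before F.scaled_dot_product_attention, looks roughly like the following self-contained sketch. This is not the library's actual processor: the module name, the use of LayerNorm as the norm layer, and the shapes are assumptions made for illustration only.

# Minimal sketch of the QK-normalization pattern from the diff (hypothetical,
# not the real diffusers attention processor).
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyQKNormAttention(nn.Module):
    def __init__(self, dim: int, heads: int = 8, qk_norm: bool = True):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)
        # Per-head LayerNorm stands in for whatever attn.norm_q / attn.norm_k
        # actually are in the real model (assumption for this sketch).
        self.norm_q = nn.LayerNorm(self.head_dim) if qk_norm else None
        self.norm_k = nn.LayerNorm(self.head_dim) if qk_norm else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # (batch, seq_len, dim) -> (batch, heads, seq_len, head_dim)
        query = query.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)

        # The change from the diff: normalize query/key when norms are configured.
        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        out = F.scaled_dot_product_attention(query, key, value)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.to_out(out)


if __name__ == "__main__":
    attn = TinyQKNormAttention(dim=64, heads=8)
    x = torch.randn(2, 16, 64)
    print(attn(x).shape)  # torch.Size([2, 16, 64])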