Commit faddb8d

Author: haileyschoelkopf
Commit message: fuse query_key_value layers
Parent: 961cee3

File tree: 1 file changed (+22, -22 lines)


src/transformers/models/bloom/modeling_flax_bloom.py

Lines changed: 22 additions & 22 deletions

@@ -48,6 +48,9 @@
 _CONFIG_FOR_DOC = "BloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
+def attention_mask_func(args):
+    # TODO: implement this helper fn. see pytorch impl. for reference
+    raise NotImplementedError
 
 
 BLOOM_START_DOCSTRING = r"""

@@ -191,28 +194,22 @@ def setup(self):
 
         self.scale_mask_softmax = FlaxBloomScaledSoftmax(
             self.config,
-            attention_mask_func,  # TODO: define this (in pytorch impl. it is a helper fn)
+            attention_mask_func,
             self.attention_softmax_in_fp32,
             self.layer_number,
         )
 
         dense = partial(
             nn.Dense,
-            self.hidden_size,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
         )
 
-        # TODO: make this one dense layer that is split into 3 on forward named self.query_key_value
-        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
-        self.dense = dense()
+
+        self.query_key_value = dense(self.hidden_size * 3)
+        self.dense = dense(self.hidden_size)
         self.attention_dropout = nn.Dropout(self.config.attention_dropout)
 
-        # TODO: check correctness of causal mask (unedited from gptneo causal mask)
-        # TODO: how to deal with reliance on max_position_embeddings?
-        max_position_embeddings = 2048
-        self.causal_mask = make_causal_mask(jnp.ones((1, max_position_embeddings), dtype="bool"), dtype="bool")
-
     def _split_heads(self, hidden_states):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
 
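The setup() change above replaces the three separate q_proj/k_proj/v_proj layers with one fused dense layer of width 3 * hidden_size, split back into query, key, and value on the forward pass (see the next hunk). As a point of reference, here is a minimal standalone Flax sketch of that pattern; the module name FusedQKVSketch and the concrete sizes are illustrative assumptions, not part of the commit:

# Minimal standalone sketch of a fused QKV projection in Flax, analogous to the
# change above. The module name and sizes are illustrative placeholders.
import jax
import jax.numpy as jnp
import flax.linen as nn


class FusedQKVSketch(nn.Module):
    hidden_size: int
    num_heads: int

    @nn.compact
    def __call__(self, hidden_states):
        # One dense layer computes queries, keys, and values in a single matmul.
        fused_qkv = nn.Dense(3 * self.hidden_size)(hidden_states)
        # Split the last dimension back into the three projections.
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)
        # Reshape each projection to (batch, seq_len, num_heads, head_dim).
        head_dim = self.hidden_size // self.num_heads
        split_heads = lambda x: x.reshape(x.shape[:2] + (self.num_heads, head_dim))
        return split_heads(query), split_heads(key), split_heads(value)


# Example: a batch of 2 sequences of length 5 with hidden size 8 and 2 heads.
x = jnp.ones((2, 5, 8))
module = FusedQKVSketch(hidden_size=8, num_heads=2)
params = module.init(jax.random.PRNGKey(0), x)
q, k, v = module.apply(params, x)
print(q.shape, k.shape, v.shape)  # (2, 5, 2, 4) for each

Splitting on the last axis assumes the fused weight stores the query, key, and value blocks contiguously; a checkpoint with a per-head interleaved query_key_value layout would need a different split.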

@@ -263,27 +260,29 @@ def __call__(
         init_cache: bool = False,
         output_attentions: bool = False,
     ):
-        # TODO: this is still from the gpt-neo impl. needs to be rewritten
-        # TODO: this needs checking for correctness of implementation.
-        # TODO: modify so that it uses self.query_key_value?
-        query = self.q_proj(hidden_states)
-        key = self.k_proj(hidden_states)
-        value = self.v_proj(hidden_states)
+        # TODO: this module __call__ needs checking for correctness of implementation.
+
+        fused_qkv = self.query_key_value(hidden_states)
+
+        query, key, value = jnp.split(fused_qkv, 3, axis=-1)
 
         query = self._split_heads(query)
         key = self._split_heads(key)
         value = self._split_heads(value)
 
         query_length, key_length = query.shape[1], key.shape[1]
 
+        # TODO: check size of hidden_states to confirm this is the right dim for causal mask to use
+        causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[0]), dtype="bool"), dtype="bool")
+
         if self.has_variable("cache", "cached_key"):
             mask_shift = self.variables["cache"]["cache_index"]
             max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
             causal_mask = lax.dynamic_slice(
-                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+                causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
             )
         else:
-            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+            causal_mask = causal_mask[:, :, :query_length, :key_length]
 
         batch_size = hidden_states.shape[0]
         causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
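
The __call__ change above uses the fused self.query_key_value projection and drops the precomputed self.causal_mask buffer, building the mask inside the forward pass and slicing it with lax.dynamic_slice when a decoding cache is present. A short standalone sketch of that slicing behavior follows, with illustrative sizes that are not taken from the commit:

# Standalone sketch: build a boolean causal mask with flax.linen.make_causal_mask
# and, for a cached decoding step, slice out the row for the current position
# with lax.dynamic_slice. The sizes here are illustrative assumptions.
import jax.numpy as jnp
from jax import lax
from flax.linen import make_causal_mask

max_decoder_length = 8
# Lower-triangular mask of shape (1, 1, max_decoder_length, max_decoder_length).
causal_mask = make_causal_mask(jnp.ones((1, max_decoder_length), dtype="bool"), dtype="bool")

# With a populated cache, only one new query position is scored against all
# cached key positions, so take a (1, 1, 1, max_decoder_length) slice starting
# at the current cache index.
mask_shift = 3  # e.g. three tokens already written to the cache
query_length = 1
step_mask = lax.dynamic_slice(
    causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
print(step_mask.astype(jnp.int32))  # position 3 may attend to 0..3, not to 4..7

As the hunk's own TODO notes, hidden_states.shape[0] is the batch dimension (the same value is used as batch_size a few lines later), so the dimension used to size the mask likely still needs to be switched to the sequence length.
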
@@ -376,8 +375,6 @@ def setup(self):
 
         self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
 
-        # TODO: should check if this line (n_head) can be removed. if so, can be removed in pytorch impl.
-        self.n_heads = self.config.n_head
         self.self_attention = FlaxBloomAttention(self.config, layer_number=self.layer_number, dtype=self.dtype)
         self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
 
@@ -421,6 +418,8 @@ def __call__(
 
         outputs = attn_outputs[1:]
 
+        attention_output = attention_output + residual
+
         post_layernorm = self.post_attention_layernorm(attention_output)
 
         # set residual based on config

@@ -431,8 +430,9 @@ def __call__(
 
         output = self.mlp(post_layernorm, residual, deterministic=deterministic)
 
-        # TODO: init_cache is separate from use_cache flag
-        if init_cache:
+        output = output + residual
+
+        if use_cache:
             outputs = (output,) + outputs
         else:
             outputs = (output,) + outputs[1:]
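
The last two hunks add explicit residual additions after the attention and MLP sub-blocks and key the cache branch on use_cache. For orientation, an illustrative-only sketch of where those residual sums sit in a pre-layer-norm block; attention, mlp, ln1, and ln2 are placeholder callables, the config-dependent choice of residual ("set residual based on config") is omitted, and unlike FlaxBloomMLP the placeholder MLP does not receive the residual as an argument:

# Illustrative-only sketch of the residual placement added above; the callables
# are placeholders, not the actual FlaxBloom modules.
import jax.numpy as jnp


def block_residuals_sketch(hidden_states, attention, mlp, ln1, ln2):
    # Attention sub-block: normalize, attend, then add the residual.
    residual = hidden_states
    attention_output = attention(ln1(hidden_states)) + residual

    # MLP sub-block: normalize the attention output, run the MLP, add its residual.
    residual = attention_output
    output = mlp(ln2(attention_output)) + residual
    return output


# Example with identity stand-ins for the sub-modules.
x = jnp.ones((2, 5, 8))
out = block_residuals_sketch(x, attention=lambda h: h, mlp=lambda h: h, ln1=lambda h: h, ln2=lambda h: h)
print(out.shape)  # (2, 5, 8)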
