 _CONFIG_FOR_DOC = "BloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"

+def attention_mask_func(args):
+    # TODO: implement this helper fn. see pytorch impl. for reference
+    raise NotImplementedError

 BLOOM_START_DOCSTRING = r"""

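Note on the new `attention_mask_func` stub above: in the Megatron-style PyTorch code the TODO points to, this helper typically just fills the masked positions of the attention scores with a large negative value before the softmax. As a rough, hedged sketch only (the signature and fill value here are assumptions, not taken from this PR):

import jax.numpy as jnp

def attention_mask_func(attention_scores, attention_mask):
    # Hypothetical sketch, not this PR's implementation: push positions where
    # attention_mask is True to a very negative value so they receive ~0 weight
    # after the softmax.
    fill_value = jnp.finfo(attention_scores.dtype).min
    return jnp.where(attention_mask, fill_value, attention_scores)
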
@@ -191,28 +194,22 @@ def setup(self):

         self.scale_mask_softmax = FlaxBloomScaledSoftmax(
             self.config,
-            attention_mask_func,  # TODO: define this (in pytorch impl. it is a helper fn)
+            attention_mask_func,
             self.attention_softmax_in_fp32,
             self.layer_number,
         )

         dense = partial(
             nn.Dense,
-            self.hidden_size,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
         )

-        # TODO: make this one dense layer that is split into 3 on forward named self.query_key_value
-        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
-        self.dense = dense()
+
+        self.query_key_value = dense(self.hidden_size * 3)
+        self.dense = dense(self.hidden_size)
         self.attention_dropout = nn.Dropout(self.config.attention_dropout)

-        # TODO: check correctness of causal mask (unedited from gptneo causal mask)
-        # TODO: how to deal with reliance on max_position_embeddings?
-        max_position_embeddings = 2048
-        self.causal_mask = make_causal_mask(jnp.ones((1, max_position_embeddings), dtype="bool"), dtype="bool")
-
     def _split_heads(self, hidden_states):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

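The fused `query_key_value` layer above replaces the three separate q/k/v projections with a single Dense of width `3 * hidden_size` that is split on the forward pass. A small self-contained sketch of that pattern (toy module and shapes, not the PR's `FlaxBloomAttention`):

import jax
import jax.numpy as jnp
import flax.linen as nn

class FusedQKV(nn.Module):
    # Illustrative toy module, not this PR's attention layer.
    hidden_size: int

    @nn.compact
    def __call__(self, hidden_states):
        # One projection to 3 * hidden_size, split into q/k/v on the forward pass.
        fused_qkv = nn.Dense(self.hidden_size * 3)(hidden_states)
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)
        return query, key, value

x = jnp.ones((2, 5, 8))  # (batch, seq_len, hidden_size), toy sizes
module = FusedQKV(hidden_size=8)
params = module.init(jax.random.PRNGKey(0), x)
query, key, value = module.apply(params, x)
print(query.shape, key.shape, value.shape)  # each (2, 5, 8)

The likely motivation is that the PyTorch BLOOM checkpoints store q/k/v as one fused `query_key_value` weight, so a fused layer keeps the parameter layout aligned; whether a flat `jnp.split` along the last axis matches how the checkpoint interleaves the heads is one of the things the remaining correctness TODO in `__call__` would need to confirm.
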
@@ -263,27 +260,29 @@ def __call__(
         init_cache: bool = False,
         output_attentions: bool = False,
     ):
-        # TODO: this is still from the gpt-neo impl. needs to be rewritten
-        # TODO: this needs checking for correctness of implementation.
-        # TODO: modify so that it uses self.query_key_value?
-        query = self.q_proj(hidden_states)
-        key = self.k_proj(hidden_states)
-        value = self.v_proj(hidden_states)
+        # TODO: this module __call__ needs checking for correctness of implementation.
+
+        fused_qkv = self.query_key_value(hidden_states)
+
+        query, key, value = jnp.split(fused_qkv, 3, axis=-1)

         query = self._split_heads(query)
         key = self._split_heads(key)
         value = self._split_heads(value)

         query_length, key_length = query.shape[1], key.shape[1]

+        # TODO: check size of hidden_states to confirm this is the right dim for causal mask to use
+        causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[0]), dtype="bool"), dtype="bool")
+
         if self.has_variable("cache", "cached_key"):
             mask_shift = self.variables["cache"]["cache_index"]
             max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
             causal_mask = lax.dynamic_slice(
-                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+                causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
             )
         else:
-            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+            causal_mask = causal_mask[:, :, :query_length, :key_length]

         batch_size = hidden_states.shape[0]
         causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
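For context on building the mask inside `__call__`: `make_causal_mask` on an input of shape `(1, L)` returns a `(1, 1, L, L)` lower-triangular boolean mask, and during cached decoding the `lax.dynamic_slice` starting at `cache_index` picks out the row for the current query position against the full cache length. A standalone sketch with illustrative sizes (none of these numbers come from the PR):

import jax.numpy as jnp
from jax import lax
from flax.linen import make_causal_mask

max_decoder_length = 8  # illustrative cache size
query_length = 1        # one new token per decoding step
mask_shift = 3          # illustrative cache_index: tokens already in the cache

# (1, 1, max_decoder_length, max_decoder_length) lower-triangular boolean mask
causal_mask = make_causal_mask(jnp.ones((1, max_decoder_length), dtype="bool"), dtype="bool")

# Slice out the row for the current query position against the full cache length.
step_mask = lax.dynamic_slice(
    causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
print(step_mask[0, 0])  # [[ True  True  True  True False False False False]]

The slice only works if the source mask already spans `max_decoder_length` rows, which is what the TODO above is flagging: `hidden_states.shape[0]` is the batch dimension (the code reads it as `batch_size` a few lines later), so the mask presumably needs to be built from the sequence dimension, or from a maximum length, rather than from the batch size.
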
@@ -376,8 +375,6 @@ def setup(self):

         self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

-        # TODO: should check if this line (n_head) can be removed. if so, can be removed in pytorch impl.
-        self.n_heads = self.config.n_head
         self.self_attention = FlaxBloomAttention(self.config, layer_number=self.layer_number, dtype=self.dtype)
         self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

@@ -421,6 +418,8 @@ def __call__(

         outputs = attn_outputs[1:]

+        attention_output = attention_output + residual
+
         post_layernorm = self.post_attention_layernorm(attention_output)

         # set residual based on config
@@ -431,8 +430,9 @@ def __call__(

         output = self.mlp(post_layernorm, residual, deterministic=deterministic)

-        # TODO: init_cache is separate from use_cache flag
-        if init_cache:
+        output = output + residual
+
+        if use_cache:
             outputs = (output,) + outputs
         else:
             outputs = (output,) + outputs[1:]
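Together with the `# set residual based on config` branch, the two added residual additions give the block the usual BLOOM pre-layernorm dataflow: attention output plus its residual, post-attention layernorm, MLP output plus its residual, with `config.apply_residual_connection_post_layernorm` (presumably the flag that comment refers to) choosing whether each residual is the layernormed tensor or the un-normalized one. The `init_cache` → `use_cache` switch likewise resolves the removed TODO by keying the extra outputs off the runtime `use_cache` flag. A runnable, simplified sketch of the block dataflow (the helpers and flag below are stand-ins, not this module's API):

import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    # Stand-in for nn.LayerNorm, for illustration only.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def bloom_block(hidden_states, attention, mlp, post_layernorm_residual=False):
    # `attention` and `mlp` are placeholder callables; `post_layernorm_residual`
    # stands in for config.apply_residual_connection_post_layernorm.
    layernorm_output = layer_norm(hidden_states)
    residual = layernorm_output if post_layernorm_residual else hidden_states

    attention_output = attention(layernorm_output) + residual

    post_layernorm = layer_norm(attention_output)
    residual = post_layernorm if post_layernorm_residual else attention_output

    return mlp(post_layernorm) + residual

# Identity stand-ins for the attention and MLP sub-modules:
x = jnp.ones((2, 4, 8))
print(bloom_block(x, attention=lambda h: h, mlp=lambda h: h).shape)  # (2, 4, 8)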