 # TODO: check that code is jit-able

 from functools import partial
-from typing import Any, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple

 import flax.linen as nn
 import jax
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """

+class FlaxBloomScaledSoftmax(nn.Module):
+    """
+    Scaled softmax module. Also performs masking.
+
+    Args:
+        mask_func (`function`, *required*):
+            mask function to be applied to the attention scores.
+        softmax_in_fp32 (`bool`, *required*):
+            if True, the softmax is performed in fp32 precision.
+        scale (`float`, *optional*):
+            scaling factor applied to the input tensor.
+    """
+
+    config: BloomConfig
+    mask_func: Callable
+    softmax_in_fp32: bool
+    scale: float = None
+
+    def setup(self):
+        if not (self.scale is None or self.softmax_in_fp32):
+            raise ValueError("softmax should be in fp32 if scale is not `None`")
+
+    def __call__(self, input, mask, causal_mask):
+        input_dtype = input.dtype
+        softmax_dtype = jnp.float32 if self.softmax_in_fp32 else input_dtype
+
+        if self.scale is not None:
+            input = input * self.scale
+
+        if mask is not None:
+            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+            # TODO: ideally we could pass a dtype argument to nn.softmax like in PyTorch (see discussion on PR #17474 about fp32 softmax)
+            # note: astype returns a new array, so the cast must be assigned back
+            mask_output = mask_output.astype(softmax_dtype)
+            probs = nn.softmax(mask_output, axis=-1) * (~padded_causal_mask)
+        else:
+            input = input.astype(softmax_dtype)
+            probs = nn.softmax(input, axis=-1)
+
+        if input_dtype != softmax_dtype:
+            probs = probs.astype(input_dtype)
+
+        return probs
+

 class FlaxBloomAttention(nn.Module):
     config: BloomConfig
@@ -133,7 +174,6 @@ def setup(self):
         self.split_size = self.hidden_size
         # TODO: deal with softmax
         self.attention_softmax_in_fp32 = self.config.attention_softmax_in_fp32
-        self.masked_softmax_fusion = self.config.masked_softmax_fusion
         # TODO: deal with hidden dropout
         self.hidden_dropout = self.config.hidden_dropout

@@ -149,13 +189,12 @@ def setup(self):

         self.attn_dropout = nn.Dropout(self.config.attention_dropout)

-        # Scaled Softmax TODO: change this to something implemented in jax (maybe implement in __call__ for attn module?)
-        # self.scale_mask_softmax = BloomScaledSoftmax(
-        #     self.masked_softmax_fusion,
-        #     attention_mask_func,
-        #     self.attention_softmax_in_fp32,
-        #     self.layer_number,
-        # )
+        self.scale_mask_softmax = FlaxBloomScaledSoftmax(
+            self.config,
+            attention_mask_func,  # TODO: define this (in pytorch impl. it is a helper fn)
+            self.attention_softmax_in_fp32,
+            self.layer_number,
+        )

         dense = partial(
             nn.Dense,
@@ -224,6 +263,7 @@ def __call__(
         init_cache: bool = False,
         output_attentions: bool = False,
     ):
+        # TODO: this is still from the gpt-neo impl. needs to be rewritten
         # TODO: this needs checking for correctness of implementation.
         # TODO: modify so that it uses self.query_key_value?
         query = self.q_proj(hidden_states)
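
The `attention_mask_func` wired into `FlaxBloomScaledSoftmax` above is still undefined (see the TODO in `setup`). Below is a minimal sketch of what such a helper might look like in JAX, assuming it mirrors the PyTorch helper's contract of returning the masked scores together with the combined padding/causal mask consumed in `__call__`. Apart from the function name, which comes from the diff, everything here is an assumption rather than the final implementation.

import jax.numpy as jnp

def attention_mask_func(attention_scores, attention_mask, causal_mask):
    # hypothetical sketch, not part of the PR
    # attention_scores: [batch, n_heads, q_length, kv_length]
    # attention_mask:   [batch, kv_length], 1 = token to attend to, 0 = padding
    # causal_mask:      [1, 1, q_length, kv_length] boolean, True where attention is allowed
    padding_mask = (attention_mask == 0)[:, None, None, :]
    # True wherever attention must be blocked (padding or future positions)
    padded_causal_mask = jnp.logical_or(padding_mask, ~causal_mask)
    # fill blocked positions with a large negative value before the softmax
    masked_scores = jnp.where(padded_causal_mask, -1e4, attention_scores)
    return masked_scores, padded_causal_mask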
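
For reference, here is a standalone toy example (not part of the PR) of the masked-softmax computation that `FlaxBloomScaledSoftmax.__call__` performs: mask the scores, softmax over the key dimension, then zero out the masked positions. The shapes and values are illustrative only.

import jax
import jax.numpy as jnp

batch, n_heads, q_len, kv_len = 1, 2, 4, 4
scores = jnp.zeros((batch, n_heads, q_len, kv_len))

# causal mask: True where a query position may attend to a key position
causal_mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))[None, None, :, :]
# padding mask: 1 = real token, 0 = padding (last key position is padding here)
attention_mask = jnp.array([[1, 1, 1, 0]])

padded_causal_mask = jnp.logical_or((attention_mask == 0)[:, None, None, :], ~causal_mask)
masked_scores = jnp.where(padded_causal_mask, -1e4, scores)
probs = jax.nn.softmax(masked_scores, axis=-1) * (~padded_causal_mask)
print(probs[0, 0])  # each row sums to ~1 over the visible, non-padded key positions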