Commit 961cee3 (1 parent: 3fc824b)
Author: haileyschoelkopf

    add FlaxBloomScaledSoftmax


src/transformers/models/bloom/modeling_flax_bloom.py (49 additions, 9 deletions)
@@ -21,7 +21,7 @@
 # TODO: check that code is jit-able
 
 from functools import partial
-from typing import Any, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 import flax.linen as nn
 import jax
@@ -116,6 +116,47 @@
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
 
+class FlaxBloomScaledSoftmax(nn.Module):
+    config: BloomConfig
+    mask_func: Callable
+    softmax_in_fp32: bool
+    scale: float = None
+    """
+    Scaled Softmax module. Also performs masking.
+    Args:
+        mask_func (`function`, *required*):
+            mask function to be applied.
+        softmax_in_fp32 (`bool`, *required*):
+            if true, softmax in performed at fp32 precision.
+        scale (`float`, *optional*):
+            scaling factor used in input tensor scaling.
+    """
+
+    def setup(self):
+        if not (self.scale is None or self.softmax_in_fp32):
+            raise ValueError("softmax should be in fp32 if scale is not `None`")
+
+    def __call__(self, input, mask, causal_mask):
+        input_dtype = input.dtype
+        softmax_dtype = jnp.float32 if self.softmax_in_fp32 else input_dtype
+
+        if self.scale is not None:
+            input = input * self.scale
+
+        if mask is not None:
+            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+            # TODO: ideally we could pass a dtype argument to nn.softmax like in PyTorch (see discussion on PR #17474 about fp32 softmax)
+            mask_output.astype(softmax_dtype)
+            probs = nn.softmax(mask_output, axis=-1) * (~padded_causal_mask)
+        else:
+            input.astype(softmax_dtype)
+            probs = nn.softmax(input, axis=-1)
+
+        if input_dtype != softmax_dtype:
+            probs = probs.astype(input_dtype)
+
+        return probs
+
 
 class FlaxBloomAttention(nn.Module):
     config: BloomConfig
@@ -133,7 +174,6 @@ def setup(self):
         self.split_size = self.hidden_size
         # TODO: deal with softmax
         self.attention_softmax_in_fp32 = self.config.attention_softmax_in_fp32
-        self.masked_softmax_fusion = self.config.masked_softmax_fusion
         # TODO: deal with hidden dropout
         self.hidden_dropout = self.config.hidden_dropout
 
@@ -149,13 +189,12 @@ def setup(self):
 
         self.attn_dropout = nn.Dropout(self.config.attention_dropout)
 
-        # Scaled Softmax TODO: change this to something implemented in jax (maybe implement in __call__ for attn module?)
-        # self.scale_mask_softmax = BloomScaledSoftmax(
-        #     self.masked_softmax_fusion,
-        #     attention_mask_func,
-        #     self.attention_softmax_in_fp32,
-        #     self.layer_number,
-        # )
+        self.scale_mask_softmax = FlaxBloomScaledSoftmax(
+            self.config,
+            attention_mask_func,  # TODO: define this (in pytorch impl. it is a helper fn)
+            self.attention_softmax_in_fp32,
+            self.layer_number,
+        )
 
         dense = partial(
             nn.Dense,
@@ -224,6 +263,7 @@ def __call__(
         init_cache: bool = False,
         output_attentions: bool = False,
     ):
+        # TODO: this is still from the gpt-neo impl. needs to be rewritten
         # TODO: this needs checking for correctness of implementation.
         # TODO: modify so that it uses self.query_key_value?
         query = self.q_proj(hidden_states)
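
Illustrative sketch (not part of the commit): the new module applies an optional scaling factor to the attention scores, upcasts them to fp32 for the softmax when `softmax_in_fp32` is set, masks out padded and future positions, and casts the probabilities back to the input dtype. The minimal standalone version below mirrors that pattern; `ToyScaledSoftmax`, `toy_attention_mask_func`, the tensor shapes, and the `scale=0.125` value are hypothetical stand-ins, since the committed module takes a `BloomConfig` and its `attention_mask_func` helper is still marked TODO in this commit.

import flax.linen as nn
import jax.numpy as jnp


def toy_attention_mask_func(scores, attention_mask, causal_mask):
    # Hypothetical stand-in for the still-undefined `attention_mask_func` helper:
    # merge the padding mask with the causal mask and fill masked positions with
    # a large negative value so they vanish after the softmax.
    padded_causal_mask = jnp.logical_or(~attention_mask, ~causal_mask)
    masked_scores = jnp.where(padded_causal_mask, -1e9, scores)
    return masked_scores, padded_causal_mask


class ToyScaledSoftmax(nn.Module):
    # Mirrors the interface of the committed module, minus the BloomConfig field.
    softmax_in_fp32: bool = True
    scale: float = None

    @nn.compact
    def __call__(self, scores, attention_mask, causal_mask):
        input_dtype = scores.dtype
        softmax_dtype = jnp.float32 if self.softmax_in_fp32 else input_dtype

        if self.scale is not None:
            scores = scores * self.scale

        masked_scores, padded_causal_mask = toy_attention_mask_func(scores, attention_mask, causal_mask)
        # astype returns a new array, so assign the upcast result before the softmax,
        # then zero out masked positions and cast back to the input dtype.
        probs = nn.softmax(masked_scores.astype(softmax_dtype), axis=-1) * (~padded_causal_mask)
        return probs.astype(input_dtype)


# Toy usage: batch 1, 1 head, sequence length 4, bfloat16 scores.
scores = jnp.zeros((1, 1, 4, 4), dtype=jnp.bfloat16)
attention_mask = jnp.ones((1, 1, 1, 4), dtype=bool)                # no padding
causal_mask = jnp.tril(jnp.ones((4, 4), dtype=bool))[None, None]   # lower-triangular
probs = ToyScaledSoftmax(scale=0.125).apply({}, scores, attention_mask, causal_mask)
print(probs.shape, probs.dtype)  # (1, 1, 4, 4) bfloat16

As in the committed `__call__`, the softmax runs in fp32 whenever scaling is enabled and the probabilities are cast back to the original dtype afterwards, so low-precision scores do not lose mass in the normalization.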
